Commit 4fea4251 authored by Qianfeng Zhang
Browse files

Rename GetZeroVal() to GetReductionZeroVal() in the kernels

parent 52ae56f8
...@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise
// LDS // LDS
__shared__ compType p_in_block_buffer[BlockBufferSize]; __shared__ compType p_in_block_buffer[BlockBufferSize];
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -243,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -243,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__ compType p_in_block_buffer[BlockBufferSize]; __shared__ compType p_in_block_buffer[BlockBufferSize];
__shared__ int block_indices_buffer[BlockBufferSize]; __shared__ int block_indices_buffer[BlockBufferSize];
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -431,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise ...@@ -431,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise
__shared__ compType p_in_block_buffer[BlockBufferSize]; __shared__ compType p_in_block_buffer[BlockBufferSize];
__shared__ int block_indices_buffer[BlockBufferSize]; __shared__ int block_indices_buffer[BlockBufferSize];
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_val_buf = const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global, make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
......
...@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
(void)ws_indices_global; (void)ws_indices_global;
(void)indices_global; (void)indices_global;
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -204,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -204,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{ {
(void)ws_indices_global; (void)ws_indices_global;
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -348,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise ...@@ -348,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
{ {
(void)origReduceLen; (void)origReduceLen;
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_val_buf = const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global, make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
......
...@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
(void)ws_indices_global; (void)ws_indices_global;
(void)indices_global; (void)indices_global;
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -215,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -215,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{ {
(void)ws_indices_global; (void)ws_indices_global;
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal)); p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
...@@ -373,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise ...@@ -373,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
{ {
(void)origReduceLen; (void)origReduceLen;
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
const auto src_global_val_buf = const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global, make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
......
...@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock ...@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(void)alpha; // unused (void)alpha; // unused
(void)beta; // unused (void)beta; // unused
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
// LDS // LDS
__shared__ compType p_in_block_buffer[BlockBufferSize]; __shared__ compType p_in_block_buffer[BlockBufferSize];
...@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock ...@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock
(void)alpha; // unused (void)alpha; // unused
(void)beta; // unused (void)beta; // unused
constexpr auto zeroVal = opReduce::GetZeroVal(); constexpr auto zeroVal = opReduce::GetReductionZeroVal();
// LDS // LDS
__shared__ compType p_in_block_values_buffer[BlockBufferSize]; __shared__ compType p_in_block_values_buffer[BlockBufferSize];
......
...@@ -56,7 +56,7 @@ struct BlockwiseReduction_2d_block_buffer ...@@ -56,7 +56,7 @@ struct BlockwiseReduction_2d_block_buffer
Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData) Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData)
{ {
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
compType lAccuData = opReduce::GetZeroVal(); compType lAccuData = opReduce::GetReductionZeroVal();
index_t offset; index_t offset;
for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++) for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
...@@ -115,7 +115,7 @@ struct BlockwiseReduction_2d_block_buffer ...@@ -115,7 +115,7 @@ struct BlockwiseReduction_2d_block_buffer
int& accuIndex) int& accuIndex)
{ {
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
compType lAccuData = opReduce::GetZeroVal(); compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0; int lAccuIndex = 0;
if constexpr(blockIsOneRow) if constexpr(blockIsOneRow)
......
...@@ -62,7 +62,7 @@ struct WarpReduce ...@@ -62,7 +62,7 @@ struct WarpReduce
// This interface implementation uses HIP built-in device shuffling functions // This interface implementation uses HIP built-in device shuffling functions
__device__ static void ReduceImpl1(const BufferType& thread_buffer, compType& accuData) __device__ static void ReduceImpl1(const BufferType& thread_buffer, compType& accuData)
{ {
compType lAccuData = opReduce::GetZeroVal(); compType lAccuData = opReduce::GetReductionZeroVal();
static_for<0, ThreadBufferLen, 1>{}( static_for<0, ThreadBufferLen, 1>{}(
[&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
...@@ -84,7 +84,7 @@ struct WarpReduce ...@@ -84,7 +84,7 @@ struct WarpReduce
// since for fp16, built-in shuffling functions is not provided by HIP // since for fp16, built-in shuffling functions is not provided by HIP
__device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData) __device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData)
{ {
compType lAccuData = opReduce::GetZeroVal(); compType lAccuData = opReduce::GetReductionZeroVal();
static_for<0, ThreadBufferLen, 1>{}( static_for<0, ThreadBufferLen, 1>{}(
[&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
...@@ -138,7 +138,7 @@ struct WarpReduce ...@@ -138,7 +138,7 @@ struct WarpReduce
int& accuIndex, int& accuIndex,
int indexStart) int indexStart)
{ {
compType lAccuData = opReduce::GetZeroVal(); compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0; int lAccuIndex = 0;
index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize; index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize;
...@@ -170,7 +170,7 @@ struct WarpReduce ...@@ -170,7 +170,7 @@ struct WarpReduce
int& accuIndex, int& accuIndex,
int indexStart) int indexStart)
{ {
compType lAccuData = opReduce::GetZeroVal(); compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0; int lAccuIndex = 0;
index_t thread_id = get_thread_local_1d_id(); index_t thread_id = get_thread_local_1d_id();
index_t warpId = thread_id / warpSize; index_t warpId = thread_id / warpSize;
...@@ -278,7 +278,7 @@ struct WarpReduceWithIndicesInput ...@@ -278,7 +278,7 @@ struct WarpReduceWithIndicesInput
compType& accuData, compType& accuData,
int& accuIndex) int& accuIndex)
{ {
compType lAccuData = opReduce::GetZeroVal(); compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0; int lAccuIndex = 0;
static_for<0, ThreadBufferLen, 1>{}([&](auto I) { static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
...@@ -307,7 +307,7 @@ struct WarpReduceWithIndicesInput ...@@ -307,7 +307,7 @@ struct WarpReduceWithIndicesInput
compType& accuData, compType& accuData,
int& accuIndex) int& accuIndex)
{ {
compType lAccuData = opReduce::GetZeroVal(); compType lAccuData = opReduce::GetReductionZeroVal();
int lAccuIndex = 0; int lAccuIndex = 0;
index_t thread_id = get_thread_local_1d_id(); index_t thread_id = get_thread_local_1d_id();
index_t warpId = thread_id / warpSize; index_t warpId = thread_id / warpSize;
......
...@@ -35,8 +35,8 @@ namespace reduce { ...@@ -35,8 +35,8 @@ namespace reduce {
// Every binary operator used in reduction is represented by a templated functor class. Each functor // Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least // class must provide at least
// three members: // three members:
// 1) GetZeroVal() -- the interface to return the "identity element" for the binary operator, // 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
// "identity element" is the unique // operator, "identity element" is the unique
// element in the algebraic space that doesn't affect the value of other elements // element in the algebraic space that doesn't affect the value of other elements
// when operated with any of them. // when operated with any of them.
// 2) indexable -- boolean value indicating whether indices of the operated elements could be // 2) indexable -- boolean value indicating whether indices of the operated elements could be
...@@ -58,7 +58,7 @@ struct Add ...@@ -58,7 +58,7 @@ struct Add
{ {
using dataType = T; using dataType = T;
__device__ static constexpr T GetZeroVal() { return static_cast<T>(0.0f); }; __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
__device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
...@@ -70,7 +70,7 @@ struct Mul ...@@ -70,7 +70,7 @@ struct Mul
{ {
using dataType = T; using dataType = T;
__device__ static constexpr T GetZeroVal() { return static_cast<T>(1.0f); }; __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
__device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
...@@ -82,7 +82,7 @@ struct Max ...@@ -82,7 +82,7 @@ struct Max
{ {
using dataType = T; using dataType = T;
__device__ static constexpr T GetZeroVal() { return NumericLimits<T>::lowest(); }; __device__ static constexpr T GetReductionZeroVal() { return NumericLimits<T>::lowest(); };
__device__ inline constexpr void operator()(T& a, T b) const __device__ inline constexpr void operator()(T& a, T b) const
{ {
...@@ -107,7 +107,7 @@ struct Min ...@@ -107,7 +107,7 @@ struct Min
{ {
using dataType = T; using dataType = T;
__device__ static constexpr T GetZeroVal() { return NumericLimits<T>::Max(); }; __device__ static constexpr T GetReductionZeroVal() { return NumericLimits<T>::Max(); };
__device__ inline constexpr void operator()(T& a, T b) const __device__ inline constexpr void operator()(T& a, T b) const
{ {
...@@ -132,7 +132,7 @@ struct AMax ...@@ -132,7 +132,7 @@ struct AMax
{ {
using dataType = T; using dataType = T;
__device__ static constexpr T GetZeroVal() { return static_cast<T>(0.0f); }; __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
__device__ inline constexpr void operator()(T& a, T b) const __device__ inline constexpr void operator()(T& a, T b) const
{ {
...@@ -281,7 +281,7 @@ struct unary_sqrt<half_t> ...@@ -281,7 +281,7 @@ struct unary_sqrt<half_t>
// The templated struct reduce_binary_operator maps the enum Ids of binary operators to their // The templated struct reduce_binary_operator maps the enum Ids of binary operators to their
// respective functor classes. // respective functor classes.
// The "GetZeroVal()" interface and boolean member "indexable" are also provided in // The "GetReductionZeroVal()" interface and boolean member "indexable" are also provided in
// reduce_binary_operactor for // reduce_binary_operactor for
// easier checking by the upper-layer codes in the kernels. // easier checking by the upper-layer codes in the kernels.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment