Unverified Commit 82c8b9f8 authored by Qianfeng, committed by GitHub

Improve Reduction kernel api (#152)

* Add ThreadwiseReduction functor as per-thread reduction api

* Use the ThreadwiseReduce api and adjust the use of the PartitionedBlockwiseReduction api to simplify the kernels

* Add comments and remove useless declarations in the kernels

* Tiny updates
parent 64687816
...@@ -26,16 +26,20 @@ ...@@ -26,16 +26,20 @@
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP #ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP #define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
namespace ck { namespace ck {
// clang-format off
// Assume:
// 1) work_buffer is a buffer (typically LDS) allocated by the caller as workspace; it holds no in/out data
// 2) work_buffer has AccDataType elements, and its size is no less than BlockSize
// 3) on entry, in_out_value is each thread's input partial value in vgpr
// 4) on return, in_out_value is overwritten with the reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType, template <typename AccDataType,
index_t BlockSize, index_t BlockSize,
typename ThreadClusterLengths_M_K, typename ThreadClusterLengths_M_K,
...@@ -61,8 +65,11 @@ struct PartitionedBlockwiseReduction ...@@ -61,8 +65,11 @@ struct PartitionedBlockwiseReduction
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>; using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename BufferType> template <typename BufferType>
__device__ static void Reduce(BufferType& block_buffer, AccDataType& accuData) __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
{ {
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");
constexpr auto cluster_len_shift = get_shift<BufferLength_K>(); constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx = const auto thread_cluster_idx =
...@@ -71,6 +78,10 @@ struct PartitionedBlockwiseReduction ...@@ -71,6 +78,10 @@ struct PartitionedBlockwiseReduction
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
__syncthreads();
static_for<0, cluster_len_shift, 1>{}([&](auto I) { static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
...@@ -80,10 +91,10 @@ struct PartitionedBlockwiseReduction ...@@ -80,10 +91,10 @@ struct PartitionedBlockwiseReduction
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset)); make_tuple(0, indOffset));
AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]); AccDataType opData1 = work_buffer[offset1];
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]); AccDataType opData2 = work_buffer[offset2];
Accumulation::Calculate(opData1, opData2); Accumulation::Calculate(opData1, opData2);
block_buffer(offset1) = type_convert<AccDataType>(opData1); work_buffer(offset1) = opData1;
} }
__syncthreads(); __syncthreads();
...@@ -91,10 +102,17 @@ struct PartitionedBlockwiseReduction ...@@ -91,10 +102,17 @@ struct PartitionedBlockwiseReduction
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
accuData = type_convert<AccDataType>(block_buffer[offset]); in_out_value = work_buffer[offset];
}; };
}; };
// clang-format off
// Assume:
// 1) work_val_buffer/work_idx_buffer is a buffer (typically LDS) allocated by the caller as workspace; it holds no in/out data
// 2) work_val_buffer/work_idx_buffer has AccDataType/IndexDataType elements, and its size is no less than BlockSize
// 3) on entry, in_out_value/in_out_index is each thread's input value/index pair in vgpr
// 4) on return, in_out_value/in_out_index is overwritten with the reduced output for each thread
// clang-format on
template <typename AccDataType, template <typename AccDataType,
typename IndexDataType, typename IndexDataType,
index_t BlockSize, index_t BlockSize,
...@@ -123,11 +141,16 @@ struct PartitionedBlockwiseReductionWithIndex ...@@ -123,11 +141,16 @@ struct PartitionedBlockwiseReductionWithIndex
// This interface accumulates on both data values and indices // This interface accumulates on both data values and indices
template <typename BufferType, typename IdxBufferType> template <typename BufferType, typename IdxBufferType>
__device__ static void Reduce(BufferType& block_val_buffer, __device__ static void Reduce(BufferType& work_val_buffer,
IdxBufferType& block_idx_buffer, IdxBufferType& work_idx_buffer,
AccDataType& accuData, AccDataType& in_out_value,
IndexDataType& accuIndex) IndexDataType& in_out_index)
{ {
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");
static_assert(is_same<typename IdxBufferType::type, IndexDataType>{},
"Buffer data type should be consistent as IndexDataType!");
constexpr auto cluster_len_shift = get_shift<BufferLength_K>(); constexpr auto cluster_len_shift = get_shift<BufferLength_K>();
const auto thread_cluster_idx = const auto thread_cluster_idx =
...@@ -136,6 +159,11 @@ struct PartitionedBlockwiseReductionWithIndex ...@@ -136,6 +159,11 @@ struct PartitionedBlockwiseReductionWithIndex
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];
work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index;
__syncthreads();
static_for<0, cluster_len_shift, 1>{}([&](auto I) { static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << I(); constexpr index_t indOffset = 1 << I();
...@@ -145,14 +173,14 @@ struct PartitionedBlockwiseReductionWithIndex ...@@ -145,14 +173,14 @@ struct PartitionedBlockwiseReductionWithIndex
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset)); make_tuple(0, indOffset));
AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]); AccDataType opData1 = work_val_buffer[offset1];
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]); AccDataType opData2 = work_val_buffer[offset2];
IndexDataType currIndex1 = block_idx_buffer[offset1]; IndexDataType currIndex1 = work_idx_buffer[offset1];
IndexDataType currIndex2 = block_idx_buffer[offset2]; IndexDataType currIndex2 = work_idx_buffer[offset2];
Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2); Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
block_val_buffer(offset1) = type_convert<AccDataType>(opData1); work_val_buffer(offset1) = opData1;
block_idx_buffer(offset1) = currIndex1; work_idx_buffer(offset1) = currIndex1;
} }
__syncthreads(); __syncthreads();
...@@ -160,9 +188,9 @@ struct PartitionedBlockwiseReductionWithIndex ...@@ -160,9 +188,9 @@ struct PartitionedBlockwiseReductionWithIndex
index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));
accuData = type_convert<AccDataType>(block_val_buffer[offset]); in_out_value = work_val_buffer[offset];
accuIndex = block_idx_buffer[offset]; in_out_index = work_idx_buffer[offset];
} };
}; };
}; // end of namespace ck }; // end of namespace ck
...
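Per the "Assume" comments added above, PartitionedBlockwiseReduction::Reduce now stages each thread's VGPR partial (in_out_value) into the LDS work buffer itself and overwrites the same variable with the block-wide result. Below is a minimal host-side C++ analogue of that pairwise tree reduction; work_buffer and in_out_value echo the names in the diff, while the Add operation, the power-of-two length, and all other details are illustrative assumptions rather than CK code.

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for detail::AccumulateWithNanCheck with OpReduce assumed to be Add.
static void accumulate(float& a, float b) { a += b; }

// Mirrors the log2(K) pairwise passes of the blockwise Reduce: entry k is combined
// with entry k + ind_offset while ind_offset halves each pass.
float blockwise_reduce_row(std::vector<float>& work_buffer)
{
    const std::size_t k_len = work_buffer.size(); // assumed to be a power of two
    for(std::size_t ind_offset = k_len / 2; ind_offset > 0; ind_offset /= 2)
    {
        // in the kernel only threads with thread_k_cluster_id < ind_offset take part,
        // separated by __syncthreads(); a sequential loop plays that role here
        for(std::size_t k = 0; k < ind_offset; ++k)
            accumulate(work_buffer[k], work_buffer[k + ind_offset]);
    }
    return work_buffer[0]; // the value each thread reads back into in_out_value
}

int main()
{
    // each entry models one thread's in_out_value already staged into LDS
    std::vector<float> work_buffer{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f};
    std::cout << blockwise_reduce_row(work_buffer) << '\n'; // prints 36
    return 0;
}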
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
...@@ -179,10 +180,10 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -179,10 +180,10 @@ struct GridwiseReduction_mk_to_m_blockwise
static constexpr auto thread_cluster_desc = static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
// Dim_K as the fastest one make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( using ThreadReduceDstDesc_M =
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{})); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough; using PassThroughOp = tensor_operation::element_wise::PassThrough;
...@@ -216,14 +217,18 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -216,14 +217,18 @@ struct GridwiseReduction_mk_to_m_blockwise
ThreadClusterArrangeOrder, ThreadClusterArrangeOrder,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
using Accumulation =
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>; using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global; (void)p_ws_indices_global;
(void)p_indices_global; (void)p_indices_global;
// LDS // LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
...@@ -232,8 +237,8 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -232,8 +237,8 @@ struct GridwiseReduction_mk_to_m_blockwise
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize()); p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_buf = auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf; in_thread_buf;
...@@ -285,38 +290,26 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -285,38 +290,26 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset)); in_elementwise_op(in_thread_buf(Number<offset>{}),
}); in_thread_buf(Number<offset>{}));
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
}); });
}); });
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++; reducedTiles++;
} while(reducedTiles < toReduceTiles); } while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads(); static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
...@@ -414,8 +407,8 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -414,8 +407,8 @@ struct GridwiseReduction_mk_to_m_blockwise
(void)p_ws_indices_global; (void)p_ws_indices_global;
// LDS // LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize]; __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
...@@ -426,15 +419,18 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -426,15 +419,18 @@ struct GridwiseReduction_mk_to_m_blockwise
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize()); p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_val_buf = auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto block_reduce_idx_buf = auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf; in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf; in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
...@@ -491,42 +487,36 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -491,42 +487,36 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_val_buf); in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values // initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(offset) = in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + J(); indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset)); in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
}); });
AccDataType tmpValue = zeroVal; AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0; IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice AccumulationWithIndex::Calculate(tmpValue,
AccumulationWithIndex::Calculate( in_thread_val_buf[Number<offset>{}],
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]); tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
}); });
// store thread local value to LDS for parallel reduction
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce( BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex); reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate( AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex); accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
}); });
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
...@@ -535,8 +525,7 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -535,8 +525,7 @@ struct GridwiseReduction_mk_to_m_blockwise
reducedTiles++; reducedTiles++;
} while(reducedTiles < toReduceTiles); } while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
...@@ -665,8 +654,8 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -665,8 +654,8 @@ struct GridwiseReduction_mk_to_m_blockwise
(void)in_elementwise_op; (void)in_elementwise_op;
// LDS // LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize]; __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
...@@ -681,10 +670,10 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -681,10 +670,10 @@ struct GridwiseReduction_mk_to_m_blockwise
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize()); p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_val_buf = auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto block_reduce_idx_buf = auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf; in_thread_val_buf;
...@@ -745,8 +734,6 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -745,8 +734,6 @@ struct GridwiseReduction_mk_to_m_blockwise
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
// index_t indexOffset = 0;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal; accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0; accu_index_buf(I) = 0;
...@@ -771,42 +758,33 @@ struct GridwiseReduction_mk_to_m_blockwise ...@@ -771,42 +758,33 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_idx_buf); in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal; AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0; IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice AccumulationWithIndex::Calculate(tmpValue,
AccumulationWithIndex::Calculate( in_thread_val_buf[Number<offset>{}],
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]); tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
}); });
// store thread local value to LDS for parallel reduction
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce( BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex); reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate( AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex); accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
}); });
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
// indexOffset += K_BlockTileSize;
reducedTiles++; reducedTiles++;
} while(reducedTiles < toReduceTiles); } while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
...
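The with-index paths in this kernel first fold each thread's K slice into a running (tmpValue, tmpIndex) pair through AccumulationWithIndex::Calculate, reduce that pair across the block, and finally merge it into accu_value_buf / accu_index_buf. The short host-side analogue below shows just the per-thread accumulation step, assuming ArgMax-like semantics and ignoring the PropagateNan handling (both are assumptions for illustration, not taken from the diff).

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using IndexDataType = int32_t;

// Counterpart of AccumulationWithIndex::Calculate(accuVal, val, accuIdx, idx):
// keep the larger value together with the index it arrived with.
static void accumulate_with_index(float& accu_val, float val,
                                  IndexDataType& accu_idx, IndexDataType idx)
{
    if(val > accu_val)
    {
        accu_val = val;
        accu_idx = idx;
    }
}

int main()
{
    // one thread's K slice of values with their global indices already attached
    std::vector<float> vals{2.f, 7.f, 5.f, 7.f};
    std::vector<IndexDataType> idxs{10, 11, 12, 13};

    float accu_val         = -1e30f; // stands in for ReduceOperation::GetReductionZeroVal()
    IndexDataType accu_idx = 0;

    for(std::size_t k = 0; k < vals.size(); ++k)
        accumulate_with_index(accu_val, vals[k], accu_idx, idxs[k]);

    std::cout << accu_val << " at index " << accu_idx << '\n'; // prints 7 at index 11
    return 0;
}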
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
...@@ -103,10 +104,10 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -103,10 +104,10 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
static constexpr auto thread_cluster_desc = static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
// Dim_K as the fastest one make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( using ThreadReduceDstDesc_M =
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{})); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType, using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize, BlockSize,
...@@ -115,6 +116,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -115,6 +116,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough; using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -138,15 +145,15 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -138,15 +145,15 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal)); p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize()); p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto block_reduce_buf = auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf; in_thread_buf;
...@@ -198,42 +205,30 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add ...@@ -198,42 +205,30 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset)); in_elementwise_op(in_thread_buf(Number<offset>{}),
}); in_thread_buf(Number<offset>{}));
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
}); });
}); });
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++; reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration); } while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc = constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
// Each block executes multiple parallel reductions on the LDS and atomic-adds its
// reduced outputs to the global locations corresponding to each invariant dimension,
// yielding a consistent reduced result for that dimension. Due to the use of vector_load,
// each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}(
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
...
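As the comment in this kernel explains, every block reduces its tile in LDS and then atomic-adds its per-row partial into the global output, so the final value of an invariant dimension is the sum of contributions from all blocks that touched it. The host-side sketch below reproduces that accumulation pattern with std::thread and a CAS loop standing in for the hardware AtomicAdd; the block/thread decomposition and the all-ones input are purely illustrative assumptions.

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

// Emulates an atomic add of a float into "global memory" via a CAS loop
// (an assumption; the kernel relies on the GPU's native atomic add instead).
static void atomic_add(std::atomic<float>& target, float value)
{
    float expected = target.load();
    while(!target.compare_exchange_weak(expected, expected + value))
    {
        // expected is refreshed with the current value on failure; retry
    }
}

int main()
{
    constexpr int num_blocks  = 4; // "blocks" splitting the reduced K dimension
    constexpr int k_per_block = 8; // elements each block reduces locally

    std::atomic<float> out{0.0f}; // one invariant-dimension output location

    std::vector<std::thread> blocks;
    for(int b = 0; b < num_blocks; ++b)
        blocks.emplace_back([&] {
            // local (LDS-like) reduction of this block's slice; inputs are all 1.0f here
            float partial = 0.0f;
            for(int k = 0; k < k_per_block; ++k)
                partial += 1.0f;
            atomic_add(out, partial); // corresponds to the AtomicAdd global store
        });

    for(auto& t : blocks)
        t.join();

    std::cout << out.load() << '\n'; // prints 32 (= num_blocks * k_per_block)
    return 0;
}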
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp" #include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
...@@ -121,10 +122,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -121,10 +122,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static constexpr auto thread_cluster_desc = static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
// Dim_K as the fastest one make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( using ThreadReduceDstDesc_M =
make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{})); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough; using PassThroughOp = tensor_operation::element_wise::PassThrough;
...@@ -151,8 +152,11 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -151,8 +152,11 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
using Accumulation = using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>; ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global; (void)p_ws_indices_global;
(void)acc_elementwise_op; (void)acc_elementwise_op;
...@@ -160,7 +164,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -160,7 +164,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_block_reduce_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
...@@ -169,8 +173,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -169,8 +173,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
auto block_reduce_buf = auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf; in_thread_buf;
...@@ -222,20 +226,17 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -222,20 +226,17 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset)); in_elementwise_op(in_thread_buf(Number<offset>{}),
}); in_thread_buf(Number<offset>{}));
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
}); });
}); });
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++; reducedTiles++;
...@@ -243,16 +244,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -243,16 +244,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
// Each block executes multiple parallel reductions on the LDS, and due to the use of
// vector_load, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}(
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed( constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})); make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
...@@ -315,8 +308,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -315,8 +308,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ index_t p_block_reduce_idx_buffer[BlockSize]; __shared__ index_t p_reduce_work_idx_buffer[BlockSize];
const auto in_global_buf = const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
...@@ -327,10 +320,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -327,10 +320,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize()); p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
auto block_reduce_val_buf = auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_val_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto block_reduce_idx_buf = auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_block_reduce_idx_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf; in_thread_val_buf;
...@@ -394,42 +387,36 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -394,42 +387,36 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_val_buf); in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values // initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(offset) = in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + J(); indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset)); in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
}); });
AccDataType tmpValue = zeroVal; AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0; IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice AccumulationWithIndex::Calculate(tmpValue,
AccumulationWithIndex::Calculate( in_thread_val_buf[Number<offset>{}],
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]); tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
}); });
// store thread local value to LDS for parallel reduction
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce( BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex); reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate( AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex); accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
}); });
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
...
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
...@@ -110,6 +111,11 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -110,6 +111,11 @@ struct GridwiseReduction_mk_to_m_threadwise
using ThreadBufferDimAccessOrder = using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type; typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough; using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -124,9 +130,11 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -124,9 +130,11 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
using Accumulation = ThreadReduceSrcDesc_M_K,
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>; ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_indices_global; (void)p_indices_global;
...@@ -175,20 +183,17 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -175,20 +183,17 @@ struct GridwiseReduction_mk_to_m_threadwise
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset)); in_elementwise_op(in_thread_buf(Number<offset>{}),
}); in_thread_buf(Number<offset>{}));
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
}); });
}); });
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize; reducedLength += KThreadSliceSize;
...@@ -200,8 +205,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -200,8 +205,7 @@ struct GridwiseReduction_mk_to_m_threadwise
accu_value_buf(I) *= alpha; accu_value_buf(I) *= alpha;
}); });
constexpr auto reduced_data_desc = constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
if constexpr(!BetaIsZero) if constexpr(!BetaIsZero)
{ {
...@@ -266,10 +270,13 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -266,10 +270,13 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan, using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
ReduceOperation, IndexDataType,
AccDataType, ThreadReduceSrcDesc_M_K,
IndexDataType>; ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)acc_elementwise_op; (void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
...@@ -282,7 +289,13 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -282,7 +289,13 @@ struct GridwiseReduction_mk_to_m_threadwise
p_indices_global, out_grid_desc_m.GetElementSpaceSize()); p_indices_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf; in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf; StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
...@@ -322,26 +335,23 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -322,26 +335,23 @@ struct GridwiseReduction_mk_to_m_threadwise
in_global_buf, in_global_buf,
thread_buffer_desc, thread_buffer_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset)); in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
});
// reduce on each thread-local slice in_elementwise_op(in_thread_val_buf(Number<offset>{}),
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { in_thread_val_buf(Number<offset>{}));
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
AccumulationWithIndex::Calculate(accu_value_buf(I),
in_thread_buf[offset],
accu_index_buf(I),
indexStart + J);
}); });
}); });
ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexStart += KThreadSliceSize; indexStart += KThreadSliceSize;
...@@ -355,8 +365,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -355,8 +365,7 @@ struct GridwiseReduction_mk_to_m_threadwise
accu_value_buf(I) *= alpha; accu_value_buf(I) *= alpha;
}); });
constexpr auto reduced_data_desc = constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
if constexpr(!BetaIsZero) if constexpr(!BetaIsZero)
{ {
...
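Throughout these kernels the hand-written thread-buffer offset I * KThreadSliceSize + J has been replaced by thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)). For the packed (row-major) M x K descriptors used here the two expressions are numerically identical; the toy constexpr analogue below only demonstrates that arithmetic and is not the real ck tensor-descriptor machinery.

#include <cstddef>
#include <iostream>

// Toy packed 2-D "descriptor": an illustration of the offset arithmetic only,
// not ck::make_naive_tensor_descriptor_packed.
template <std::size_t M, std::size_t K>
struct PackedDesc_M_K
{
    static constexpr std::size_t CalculateOffset(std::size_t iM, std::size_t iK)
    {
        return iM * K + iK; // row-major packed layout: K is the fastest dimension
    }
};

int main()
{
    constexpr std::size_t MThreadSliceSize = 4;
    constexpr std::size_t KThreadSliceSize = 8;
    using Desc = PackedDesc_M_K<MThreadSliceSize, KThreadSliceSize>;

    // same value as the old hand-written "I * KThreadSliceSize + J" indexing
    static_assert(Desc::CalculateOffset(2, 3) == 2 * KThreadSliceSize + 3, "layouts agree");

    std::cout << Desc::CalculateOffset(2, 3) << '\n'; // prints 19
    return 0;
}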
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
#define CK_REDUCTION_FUNCTIONS_THREADWISE_HPP
#include "reduction_functions_accumulate.hpp"
namespace ck {
// Assume
// 1) SrcDesc is known at compile-time
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
template <typename AccDataType,
typename SrcThreadDesc_M_K,
typename DstThreadDesc_M,
typename OpReduce,
bool PropagateNan>
struct ThreadwiseReduction
{
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
static constexpr auto dst_thread_desc_m = DstThreadDesc_M{};
static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename SrcBufferType, typename DstBufferType>
__device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
{
static_for<0, src_length_m, 1>{}([&](auto iM) {
constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
static_for<0, src_length_k, 1>{}([&](auto iK) {
constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
Accumulation::Calculate(dst_buf(Number<out_offset>{}), src_buf[Number<offset>{}]);
});
});
};
};
// Assume
// 1) SrcDesc is known at compile-time
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
template <typename AccDataType,
typename IndexDataType,
typename SrcThreadDesc_M_K,
typename DstThreadDesc_M,
typename OpReduce,
bool PropagateNan>
struct ThreadwiseReductionWithIndex
{
static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
static constexpr auto dst_thread_desc_m = DstThreadDesc_M{};
static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
template <typename SrcValueBufferType,
typename SrcIndexBufferType,
typename DstValueBufferType,
typename DstIndexBufferType>
__device__ static void Reduce(const SrcValueBufferType& src_val_buf,
const SrcIndexBufferType& src_idx_buf,
DstValueBufferType& dst_val_buf,
DstIndexBufferType& dst_idx_buf)
{
static_for<0, src_length_m, 1>{}([&](auto iM) {
constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
static_for<0, src_length_k, 1>{}([&](auto iK) {
constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
Accumulation::Calculate(dst_val_buf(Number<out_offset>{}),
src_val_buf[Number<offset>{}],
dst_idx_buf(Number<out_offset>{}),
src_idx_buf[Number<offset>{}]);
});
});
};
};
}; // end of namespace ck
#endif
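To make the contract of the new header concrete: ThreadwiseReduction::Reduce folds each row of a thread-local M x K slice into an M-length accumulator that the caller has pre-filled with the reduction's identity value (GetReductionZeroVal in the kernels). The host-side analogue below assumes OpReduce is Max and uses std::array in place of the VGPR StaticBuffers and plain loops in place of static_for; those substitutions are assumptions made for readability, not CK code.

#include <algorithm>
#include <array>
#include <iostream>
#include <limits>

// Host-side sketch of the ThreadwiseReduction contract with OpReduce assumed to be Max.
template <typename AccDataType, int M, int K>
void threadwise_reduce(const std::array<AccDataType, M * K>& src_buf,
                       std::array<AccDataType, M>& dst_buf)
{
    for(int iM = 0; iM < M; ++iM)     // outer static_for over the M slice
        for(int iK = 0; iK < K; ++iK) // inner static_for over the K slice
            dst_buf[iM] = std::max(dst_buf[iM], src_buf[iM * K + iK]); // Accumulation::Calculate
}

int main()
{
    constexpr int M = 2, K = 4;
    std::array<float, M * K> src{3.f, 1.f, 4.f, 1.f, 5.f, 9.f, 2.f, 6.f};

    // dst plays the role of accu_value_buf, pre-filled with the reduction identity
    std::array<float, M> dst;
    dst.fill(std::numeric_limits<float>::lowest());

    threadwise_reduce<float, M, K>(src, dst);
    std::cout << dst[0] << ' ' << dst[1] << '\n'; // prints 4 9
    return 0;
}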