Commit bb1f8082 authored by root

Merge remote-tracking branch 'origin/develop' into myamlak/cgemm

parents 97ac5007 82d7d993
@@ -23,75 +23,86 @@
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP

#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"

namespace ck {

template <typename GridwiseReduction,
          bool OutputIndex,
          bool HaveIndexInput,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
                                         const OutGridDesc_M out_grid_desc_m,
                                         const InElementwiseOperation in_elementwise_op,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         index_t block_group_size,
                                         index_t num_k_block_tile_iteration,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_value_global,
                                         const IndexDataType* const __restrict__ p_in_index_global,
                                         AccDataType beta,
                                         OutDataType* const __restrict__ p_out_value_global,
                                         IndexDataType* const __restrict__ p_out_index_global)
{
    if constexpr(!OutputIndex)
    {
        (void)p_in_index_global;
        (void)p_out_index_global;

        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               block_group_size,
                               num_k_block_tile_iteration,
                               alpha,
                               p_in_value_global,
                               beta,
                               p_out_value_global);
    }
    else
    {
        GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
                                                                 out_grid_desc_m,
                                                                 in_elementwise_op,
                                                                 acc_elementwise_op,
                                                                 num_k_block_tile_iteration,
                                                                 alpha,
                                                                 p_in_value_global,
                                                                 p_in_index_global,
                                                                 beta,
                                                                 p_out_value_global,
                                                                 p_out_index_global);
    };
};

template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadClusterSize,
@@ -101,14 +112,13 @@ template <typename InDataType,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (MThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                          BlockSize,
                                                          ThreadClusterLengths_M_K,
                                                          ThreadClusterArrangeOrder,
                                                          ReduceOperation,
                                                          PropagateNan>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                 ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M,
                                                 ReduceOperation,
                                                 PropagateNan>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        // LDS
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        const auto in_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                          in_grid_desc_m_k.GetElementSpaceSize(),
                                                          type_convert<InDataType>(zeroVal));
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);
@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        static_for<0, MThreadSliceSize, 1>{}(
            [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if(block_group_size == 0 && !float_equal_zero{}(beta))
            {
                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                     OutDataType,
                                                     OutGridDesc_M,
                                                     decltype(reduced_data_desc),
                                                     Sequence<MThreadSliceSize>,
                                                     Sequence<0>,
                                                     0,
                                                     OutDstVectorSize,
                                                     1,
                                                     false>(
                        out_grid_desc_m,
                        make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize));

                threadwise_dst_load.Run(out_grid_desc_m,
                                        out_global_val_buf,
                                        reduced_data_desc,
                                        make_tuple(I0),
                                        priorDstValueBuf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
                });
            };

            auto threadwise_dst_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   OutMemoryDataOperation,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_dst_store.Run(reduced_data_desc,
                                     make_tuple(I0),
                                     accu_value_buf,
                                     out_grid_desc_m,
                                     out_global_val_buf);
        }
    };

    template <bool HaveIndexInput>
    __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                        const OutGridDesc_M& out_grid_desc_m,
                                        const InElementwiseOperation in_elementwise_op,
                                        const AccElementwiseOperation acc_elementwise_op,
                                        index_t num_k_block_tile_iteration,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_value_global,
                                        const IndexDataType* const __restrict__ p_in_index_global,
                                        AccDataType beta,
                                        OutDataType* const __restrict__ p_out_value_global,
                                        IndexDataType* const __restrict__ p_out_index_global)
    {
        using BlockwiseReduceWithIndex =
            PartitionedBlockwiseReductionWithIndex<AccDataType,
                                                   IndexDataType,
                                                   BlockSize,
                                                   Sequence<MThreadClusterSize, KThreadClusterSize>,
                                                   ThreadClusterArrangeOrder,
                                                   ReduceOperation,
                                                   PropagateNan>;
@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
                                                   AccDataType,
                                                   IndexDataType>;

        (void)in_elementwise_op;

        // LDS
        __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
        __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto in_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                          in_grid_desc_m_k.GetElementSpaceSize(),
                                                          type_convert<InDataType>(zeroVal));
        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_index_global, out_grid_desc_m.GetElementSpaceSize());

        auto reduce_work_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_val_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     IndexDataType,
                     MThreadSliceSize * KThreadSliceSize,
@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

        const index_t thread_local_id    = get_thread_local_1d_id();
        const index_t block_global_1d_id = get_block_1d_id();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1,
                                             false>(
                in_grid_desc_m_k,
                make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = zeroVal;
            accu_index_buf(I) = 0;
        });

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        index_t reducedTiles = 0;

        if constexpr(HaveIndexInput)
        {
            auto threadwise_src_idx_load =
                ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                                 IndexDataType,
                                                 InGridDesc_M_K,
                                                 decltype(thread_buffer_desc),
                                                 ThreadBufferLengths,
                                                 ThreadBufferDimAccessOrder,
                                                 InSrcVectorDim,
                                                 InSrcVectorSize,
                                                 1,
                                                 false>(
                    in_grid_desc_m_k,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * KThreadSliceSize));

            do
            {
                // load the thread slice
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                threadwise_src_idx_load.Run(in_grid_desc_m_k,
                                            in_global_idx_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_idx_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    AccDataType tmpValue   = zeroVal;
                    IndexDataType tmpIndex = 0;

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        AccumulationWithIndex::Calculate(tmpValue,
                                                         in_thread_val_buf[Number<offset>{}],
                                                         tmpIndex,
                                                         in_thread_idx_buf[Number<offset>{}]);
                    });

                    BlockwiseReduceWithIndex::Reduce(
                        reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);

                    AccumulationWithIndex::Calculate(
                        accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
                });

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
                threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
        else
        {
            index_t indexOffset = 0;

            do
            {
                // load the thread slice
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        // initialize the indices for the per-thread to-reduce values
                        in_thread_idx_buf(Number<offset>{}) =
                            indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();

                        // do element-wise pre-reduction operation
                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });

                    AccDataType tmpValue   = zeroVal;
                    IndexDataType tmpIndex = 0;

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        AccumulationWithIndex::Calculate(tmpValue,
                                                         in_thread_val_buf[Number<offset>{}],
                                                         tmpIndex,
                                                         in_thread_idx_buf[Number<offset>{}]);
                    });

                    BlockwiseReduceWithIndex::Reduce(
                        reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);

                    AccumulationWithIndex::Calculate(
                        accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
                });

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexOffset += K_BlockTileSize;
                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        };

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                // for indexed operations, acc_elementwise_op should do nothing
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if(!float_equal_zero{}(beta))
            {
                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                     OutDataType,
                                                     OutGridDesc_M,
                                                     decltype(reduced_data_desc),
                                                     Sequence<MThreadSliceSize>,
                                                     Sequence<0>,
                                                     0,
                                                     OutDstVectorSize,
                                                     1,
                                                     true>(
                        out_grid_desc_m,
                        make_multi_index(block_global_1d_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize));

                threadwise_dst_load.Run(out_grid_desc_m,
                                        out_global_val_buf,
                                        reduced_data_desc,
                                        make_tuple(I0),
                                        priorDstValueBuf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
                });
            };

            auto threadwise_dst_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_dst_val_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_value_buf,
                                         out_grid_desc_m,
                                         out_global_val_buf);

            threadwise_dst_idx_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_index_buf,
                                         out_grid_desc_m,
                                         out_global_idx_buf);
        }
    };
};
...
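Compared with the old partial-reduce kernel, which wrote per-block partial results into a workspace, the merged Run() above produces final values: it applies acc_elementwise_op, scales by alpha, and blends a prior output with beta. A hedged host-side sketch of that arithmetic (assuming pass-through elementwise operators; the helper name and types below are illustrative, not part of the diff):

// Reference sketch only:
//     out[m] = alpha * reduce_k(in[m][k]) + beta * prior_out[m]
#include <cstddef>
#include <vector>

template <typename T, typename ReduceOp>
std::vector<T> reduce_rows_reference(const std::vector<std::vector<T>>& in,
                                     const std::vector<T>& prior_out,
                                     T alpha,
                                     T beta,
                                     ReduceOp reduce_op, // e.g. [](T a, T b) { return a + b; }
                                     T identity)         // e.g. T{0} for a sum reduction
{
    std::vector<T> out(in.size());
    for(std::size_t m = 0; m < in.size(); ++m)
    {
        T acc = identity; // corresponds to ReduceOperation::GetReductionZeroVal()
        for(std::size_t k = 0; k < in[m].size(); ++k)
            acc = reduce_op(acc, in[m][k]); // in_elementwise_op assumed pass-through
        out[m] = alpha * acc + beta * prior_out[m]; // the alpha/beta epilogue of Run()
    }
    return out;
}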
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename AccElementwiseOperation>
__global__ void
kernel_reduce_multiblock_atocmi_add(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
block_group_size,
num_k_block_tile_iteration,
alpha,
p_in_global,
p_out_global);
};
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// Each block executes multiple parallel reductions on the LDS and atomically adds its
// reduced output to the global location of the corresponding invariant dimension, so that
// a consistent reduced result is obtained for each invariant dimension. Because vector
// loads are used, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::AtomicAdd,
1,
true>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
}
};
};
} // namespace ck
#endif
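The atomic-add variant above splits each reduction over block_group_size thread blocks: blocks sharing a blkgroup_id cover the same output rows but disjoint K slices, and each one atomically adds its alpha-scaled partial result into the output (there is no beta path here, so the destination is presumably expected to be zero-initialized beforehand). A small hedged sketch of the index math, with made-up tile sizes:

#include <cstdio>

int main()
{
    // Illustrative tile sizes only; real values come from the tuning parameters.
    const int M_BlockTileSize            = 64;
    const int K_BlockTileSize            = 32;
    const int block_group_size           = 4;
    const int num_k_block_tile_iteration = 8;
    const int reduceSizePerBlock         = K_BlockTileSize * num_k_block_tile_iteration;

    // Two block groups of four blocks each, mirroring blkgroup_id / block_local_id in Run().
    for(int block_global_id = 0; block_global_id < 2 * block_group_size; ++block_global_id)
    {
        const int blkgroup_id    = block_global_id / block_group_size; // which M tile
        const int block_local_id = block_global_id % block_group_size; // which K slice

        std::printf("block %d -> rows [%d, %d), k-range [%d, %d), partial atomically added\n",
                    block_global_id,
                    blkgroup_id * M_BlockTileSize,
                    (blkgroup_id + 1) * M_BlockTileSize,
                    block_local_id * reduceSizePerBlock,
                    (block_local_id + 1) * reduceSizePerBlock);
    }
    return 0;
}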
@@ -37,7 +37,8 @@
namespace ck {

template <typename GridwiseReduction,
          bool OutputIndex,
          bool HaveIndexInput,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                         const InElementwiseOperation in_elementwise_op,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_value_global,
                                         const IndexDataType* const __restrict__ p_in_index_global,
                                         AccDataType beta,
                                         OutDataType* const __restrict__ p_out_value_global,
                                         IndexDataType* const __restrict__ p_out_index_global)
{
    if constexpr(!OutputIndex)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               alpha,
                               p_in_value_global,
                               beta,
                               p_out_value_global);
    }
    else
    {
        GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
                                                                 out_grid_desc_m,
                                                                 in_elementwise_op,
                                                                 acc_elementwise_op,
                                                                 alpha,
                                                                 p_in_value_global,
                                                                 p_in_index_global,
                                                                 beta,
                                                                 p_out_value_global,
                                                                 p_out_index_global);
    };
};
@@ -91,11 +93,9 @@ template <typename InDataType,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
        using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                     ThreadReduceSrcDesc_M_K,
@@ -136,14 +135,14 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                     ReduceOperation,
                                                     PropagateNan>;

        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto in_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                          in_grid_desc_m_k.GetElementSpaceSize(),
                                                          type_convert<InDataType>(zeroVal));
        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;
@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise
        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1,
                                             false>(
                in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t reducedLength = 0;
        do
        {
            threadwise_src_val_load.Run(in_grid_desc_m_k,
                                        in_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                // do element-wise pre-reduction operation
@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise
            ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);

            threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);
@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise
        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        if(!float_equal_zero{}(beta))
        {
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                                        OutDataType,
                                                                        OutGridDesc_M,
                                                                        decltype(reduced_data_desc),
                                                                        Sequence<MThreadSliceSize>,
                                                                        Sequence<0>,
                                                                        0,
                                                                        1,
                                                                        1,
                                                                        true>(
                out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                priorDstValue_buf;

            threadwise_dst_load.Run(out_grid_desc_m,
                                    dst_global_buf,
                                    reduced_data_desc,
                                    make_tuple(I0),
                                    priorDstValue_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
            });
        };

        auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                                       OutDataType,
                                                                       decltype(reduced_data_desc),
                                                                       OutGridDesc_M,
                                                                       PassThroughOp,
                                                                       Sequence<MThreadSliceSize>,
                                                                       Sequence<0>,
                                                                       0,
                                                                       OutDstVectorSize,
                                                                       OutMemoryDataOperation,
                                                                       1,
                                                                       false>(
            out_grid_desc_m,
            make_multi_index(thread_global_1d_id * MThreadSliceSize),
            PassThroughOp{});

        threadwise_dst_store.Run(
            reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
    };

    template <bool HaveIndexInput>
    __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                        const OutGridDesc_M& out_grid_desc_m,
                                        const InElementwiseOperation& in_elementwise_op,
                                        const AccElementwiseOperation& acc_elementwise_op,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_value_global,
                                        const IndexDataType* const __restrict__ p_in_index_global,
                                        AccDataType beta,
                                        OutDataType* const __restrict__ p_out_value_global,
                                        IndexDataType* const __restrict__ p_out_index_global)
    {
        using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
                                                                       IndexDataType,
@@ -281,12 +278,17 @@ struct GridwiseReduction_mk_to_m_threadwise
        const auto zeroVal = ReduceOperation::GetReductionZeroVal();

        const auto in_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                          in_grid_desc_m_k.GetElementSpaceSize(),
                                                          type_convert<InDataType>(zeroVal));
        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_index_global, out_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_val_buf;
@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise
        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1,
                                             false>(
                in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t indexStart    = 0;
        index_t reducedLength = 0;

        if constexpr(HaveIndexInput)
        {
            auto threadwise_src_idx_load =
                ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                                 IndexDataType,
                                                 InGridDesc_M_K,
                                                 decltype(thread_buffer_desc),
                                                 ThreadBufferLengths,
                                                 ThreadBufferDimAccessOrder,
                                                 InSrcVectorDim,
                                                 InSrcVectorSize,
                                                 1,
                                                 false>(
                    in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                threadwise_src_idx_load.Run(in_grid_desc_m_k,
                                            in_global_idx_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_idx_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduceWithIndex::Reduce(
                    in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
                threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexStart += KThreadSliceSize;
                reducedLength += KThreadSliceSize;
            } while(reducedLength < toReduceLength);
        }
        else
        {
            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduceWithIndex::Reduce(
                    in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexStart += KThreadSliceSize;
                reducedLength += KThreadSliceSize;
            } while(reducedLength < toReduceLength);
        };

        // for indexed operations, acc_elementwise_op should do nothing
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise
        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        if(!float_equal_zero{}(beta))
        {
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                                        OutDataType,
                                                                        OutGridDesc_M,
                                                                        decltype(reduced_data_desc),
                                                                        Sequence<MThreadSliceSize>,
                                                                        Sequence<0>,
                                                                        0,
                                                                        1,
                                                                        1,
                                                                        false>(
                out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                priorDstValue_buf;

            threadwise_dst_load.Run(out_grid_desc_m,
                                    out_global_val_buf,
                                    reduced_data_desc,
                                    make_tuple(I0),
                                    priorDstValue_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
            });
        };

        auto threadwise_dst_val_store =
@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                               Sequence<0>,
                                               0,
                                               OutDstVectorSize,
                                               OutMemoryDataOperation,
                                               1,
                                               false>(
                out_grid_desc_m,
@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                               Sequence<0>,
                                               0,
                                               OutDstVectorSize,
                                               OutMemoryDataOperation,
                                               1,
                                               false>(
                out_grid_desc_m,
...
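The threadwise kernel changes above mirror the multiblock ones: RunWithIndices becomes a RunWithIndex<HaveIndexInput> template that either generates K indices on the fly or consumes index data supplied through p_in_index_global, and the compile-time BetaIsZero switch is replaced by a runtime float_equal_zero{}(beta) check. This looks intended to let a second pass consume partial indices produced by an earlier (for example, multiblock) pass; a hedged host-side sketch of such a two-pass indexed reduction, using argmax and illustrative names only:

#include <utility>
#include <vector>

using ValIdx = std::pair<float, int>; // (value, original K index)

static ValIdx argmax_slice(const std::vector<float>& row, int k_begin, int k_end)
{
    ValIdx best{row[k_begin], k_begin};
    for(int k = k_begin + 1; k < k_end; ++k)
        if(row[k] > best.first)
            best = {row[k], k};
    return best;
}

int main()
{
    const std::vector<float> row{1.f, 9.f, 3.f, 7.f, 2.f, 8.f, 5.f, 4.f};
    const int slice = 4; // stands in for the K range handled by one first-pass block

    // Pass 1 (no index input): each slice produces a partial (value, original index) pair.
    std::vector<ValIdx> partials;
    for(int k = 0; k < static_cast<int>(row.size()); k += slice)
        partials.push_back(argmax_slice(row, k, k + slice));

    // Pass 2 (HaveIndexInput == true): reduce the partials while keeping the carried
    // indices instead of regenerating them from the slice offset.
    ValIdx best = partials[0];
    for(const auto& p : partials)
        if(p.first > best.first)
            best = p;
    // best is now {9.f, 1}: the maximum value and its index in the original row.
    return 0;
}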
@@ -11,138 +11,140 @@ template <typename GridwiseBinEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename ElementwiseFunctor>
__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                             const BDataType* __restrict__ p_b_global,
                                             CDataType* __restrict__ p_c_global,
                                             const AGridDesc_M a_grid_desc_m,
                                             const BGridDesc_M b_grid_desc_m,
                                             const CGridDesc_M c_grid_desc_m,
                                             const ElementwiseFunctor functor)
{
    GridwiseBinEltwise::Run(
        p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor);
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename ElementwiseFunctor,
          index_t MPerThread,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * MPerThread);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
                               const AGridDesc_M a_grid_desc_m,
                               const BGridDesc_M b_grid_desc_m,
                               const CGridDesc_M c_grid_desc_m,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
                                             AGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             AScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{a_grid_desc_m, thread_store_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
                                             BGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             BScalarPerVector,     // ScalarPerVector
1, // SrcScalarStrideInVector 1, // SrcScalarStrideInVector
false>{b_grid_desc_m0, thread_store_global_offset}; false>{b_grid_desc_m, thread_store_global_offset};
auto c_global_write = auto c_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType, ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
CDataType, CDataType,
decltype(thread_desc_m0), decltype(thread_desc_m),
GridDesc_M0, CGridDesc_M,
PassThrough, PassThrough,
Sequence<ScalarPerVector>, // SliceLengths Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder Sequence<0>, // DimAccessOrder
0, // DstVectorDim 0, // DstVectorDim
ScalarPerVector, CScalarPerVector, // ScalarPerVector
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
false>{ false>{
c_grid_desc_m0, thread_store_global_offset, PassThrough{}}; c_grid_desc_m, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size(); const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size(); const index_t blockPerGrid = get_grid_size();
const auto m0 = c_grid_desc_m0.GetLength(I0); const auto M = c_grid_desc_m.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector; const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step); const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = m0 / (loop_step); index_t num_iter = M / (loop_step);
do do
{ {
// read and process ScalarPerVector elements // read and process MPerThread elements
a_global_load.Run( a_global_load.Run(
a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf); a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
b_global_load.Run( b_global_load.Run(
b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf); b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
static_for<0, ScalarPerVector, 1>{}([&](auto m) { static_for<0, MPerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m)); constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
functor(c_thread_buf(Number<offset>{}), functor(c_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}), a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{})); b_thread_buf(Number<offset>{}));
}); });
c_global_write.Run(thread_desc_m0, c_global_write.Run(thread_desc_m,
make_tuple(I0), // SrcSliceOriginIdx make_tuple(I0), // SrcSliceOriginIdx
c_thread_buf, c_thread_buf,
c_grid_desc_m0, c_grid_desc_m,
c_global_buf); c_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index); a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index); b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index); c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index);
} while(--num_iter); } while(--num_iter);
} }
}; };
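The grid-stride loop in Run() above only works when the 1D problem length M is an exact multiple of the total loop step, since num_iter is computed with integer division and the body executes in a do/while. Below is a minimal host-side sketch of that arithmetic; every number (M, block and grid sizes, MPerThread) is an assumption chosen purely for illustration, not a value the library uses.

// Standalone sketch of the loop-trip math in GridwiseBinaryElementwise_1D::Run.
// All sizes below are illustrative assumptions.
#include <cassert>
#include <cstdio>

int main()
{
    const int M            = 1 << 20; // total elements in the 1D problem
    const int blockSize    = 256;     // threads per block (assumed)
    const int blockPerGrid = 64;      // blocks in the grid (assumed)
    const int MPerThread   = 8;       // elements handled per thread per iteration

    const int loop_step = blockPerGrid * blockSize * MPerThread;
    assert(M % loop_step == 0);       // the do/while in Run relies on an exact multiple
    const int num_iter  = M / loop_step;

    std::printf("loop_step = %d, num_iter = %d\n", loop_step, num_iter); // 131072, 8
    return 0;
}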
......
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP #pragma once
#define CK_GRIDWISE_GEMM_V1R3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp" #include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_dl_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v5r1.hpp" #include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename AK0M0M1K1GridDesc, typename AGridDesc_K0_M0_M1_K1,
typename BK0N0N1K1GridDesc, typename BGridDesc_K0_N0_N1_K1,
typename CM0M10M11N0N10N11GridDesc, typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename CBlockIdToM0N0BlockClusterAdaptor, typename Block2CTileMap,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_dlops_v1r3( kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, const Block2CTileMap block_2_ctile_map)
const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -43,10 +43,10 @@ __global__ void ...@@ -43,10 +43,10 @@ __global__ void
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared_block,
a_k0_m0_m1_k1_grid_desc, a_grid_desc_k0_m0_m1_k1,
b_k0_n0_n1_k1_grid_desc, b_grid_desc_k0_n0_n1_k1,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_grid_desc_m0_m10_m11_n0_n10_n11,
cblockid_to_m0_n0_block_cluster_adaptor, block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
...@@ -56,12 +56,12 @@ template <index_t BlockSize, ...@@ -56,12 +56,12 @@ template <index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AK0MK1GridDesc, typename AGridDesc_K0_M_K1,
typename BK0NK1GridDesc, typename BGridDesc_K0_N_K1,
typename CMNGridDesc, typename CGridDesc_M_N,
index_t MPerBlockM1, index_t MPerBlock,
index_t NPerBlockN1, index_t NPerBlock,
index_t KPerBlock, index_t K0PerBlock,
index_t M1PerThreadM111, index_t M1PerThreadM111,
index_t N1PerThreadN111, index_t N1PerThreadN111,
index_t KPerThread, index_t KPerThread,
...@@ -83,13 +83,8 @@ template <index_t BlockSize, ...@@ -83,13 +83,8 @@ template <index_t BlockSize,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector>
typename AGridStepHacks, struct GridwiseGemmDl_km_kn_mn_v1r3
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -97,7 +92,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -97,7 +92,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -106,112 +101,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -106,112 +101,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
} }
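As a rough worked example of the LDS budget returned above (tile sizes and the fp16 element size are assumptions; padding from the K1 alignment is ignored because every dimension here is already a multiple of K1):

// Illustrative LDS-size calculation mirroring GetSharedMemoryNumberOfByte above.
// Tile sizes and element size are assumptions, not library defaults.
#include <cstdio>

int main()
{
    const int K0PerBlock = 8, MPerBlock = 128, NPerBlock = 128, K1 = 2;
    const int sizeof_FloatAB = 2; // fp16 assumed

    const int a_elems = K0PerBlock * MPerBlock * K1; // 2048
    const int b_elems = K0PerBlock * NPerBlock * K1; // 2048

    // the factor 2 accounts for the A/B double buffering used by the main K loop
    const int lds_bytes = 2 * (a_elems + b_elems) * sizeof_FloatAB;
    std::printf("LDS required: %d bytes\n", lds_bytes); // 16384
    return 0;
}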
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CMNGridDesc& c_m_n_grid_desc) const CGridDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
} }
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{ {
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size; return grid_size;
} }
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{ {
const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop; return has_main_k_block_loop;
} }
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{ {
const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop; return has_double_tail_k_block_loop;
} }
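The two predicates above decide which K-loop structure the kernel takes: an optional double-buffered main loop followed by either a two-iteration or a one-iteration tail. A small host-side check of the formulas, with K0PerBlock and the K0 values chosen only for illustration:

// Illustrative evaluation of CalculateHasMainKBlockLoop / CalculateHasDoubleTailKBlockLoop.
// K0PerBlock and the K0 values are example assumptions.
#include <cstdio>
#include <initializer_list>

int main()
{
    const int K0PerBlock = 8;
    for(int K0 : {8, 16, 24, 32}) // all multiples of K0PerBlock, as CheckValidity requires
    {
        const bool has_main_loop   = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
        const bool has_double_tail = (K0 / K0PerBlock) % 2 == 0;
        std::printf("K0=%2d -> main loop %d, double tail %d\n", K0, has_main_loop, has_double_tail);
    }
    // K0= 8 -> 0 0 : no main loop, single-iteration tail
    // K0=16 -> 0 1 : no main loop, two-iteration tail
    // K0=24 -> 1 0 : main loop, single-iteration tail
    // K0=32 -> 1 1 : main loop, two-iteration tail
    return 0;
}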
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
{ {
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto M1 = Number<MPerBlockM1>{}; const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1; const auto M0 = M / M1;
const auto a_k0_m0_m1_k1_grid_desc = const auto a_grid_desc_k0_m0_m1_k1 =
transform_tensor_descriptor(a_k0_m_k1_grid_desc, transform_tensor_descriptor(a_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)), make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_grid_desc; return a_grid_desc_k0_m0_m1_k1;
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{ {
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto N1 = Number<NPerBlockN1>{}; const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1; const auto N0 = N / N1;
const auto b_k0_n0_n1_k1_grid_desc = const auto b_grid_desc_k0_n0_n1_k1 =
transform_tensor_descriptor(b_k0_n_k1_grid_desc, transform_tensor_descriptor(b_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)), make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_grid_desc; return b_grid_desc_k0_n0_n1_k1;
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = c_m_n_grid_desc.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{}; constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlockN1>{}; constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1; const auto M0 = M / M1;
const auto N0 = N / N1; const auto N0 = N / N1;
...@@ -226,41 +221,29 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -226,41 +221,29 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto M10 = M1 / M11; constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11; constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_m_n_grid_desc, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))), make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_m0_m10_m11_n0_n10_n11_grid_desc; return c_grid_desc_m0_m10_m11_n0_n10_n11;
} }
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = c_m_n_grid_desc.GetLength(I0); return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
const auto N = c_m_n_grid_desc.GetLength(I1); c_grid_desc_m_n);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto cblockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return cblockid_to_m0_n0_block_cluster_adaptor;
} }
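For reference, the removed adaptor on the left simply linearizes the C tile grid: assuming the merge transform composes (M0, N0) row-major, block_id = m0 * N0 + n0, so CalculateBottomIndex recovers m0 = block_id / N0 and n0 = block_id % N0. A toy illustration with an arbitrary 4 x 8 tile grid (the ordering is my reading of the merge transform, not stated explicitly in this file):

// Toy illustration of the block_id -> (m0, n0) mapping implied by the merge transform
// on the left, assuming row-major linearization over (M0, N0).
#include <cstdio>
#include <initializer_list>

int main()
{
    const int M0 = 4, N0 = 8; // assumed C tile grid: grid_size = M0 * N0 = 32 blocks
    for(int block_id : {0, 7, 8, 31})
        std::printf("block %2d -> (m0=%d, n0=%d)\n", block_id, block_id / N0, block_id % N0);
    return 0;
}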
using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); using CGridDesc_M0_M10_M11_N0_N10_N11 =
using CBlockIdToM0N0BlockClusterAdaptor = decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void __device__ static void
...@@ -268,57 +251,64 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -268,57 +251,64 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, const Block2CTileMap& block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx = const auto c_m0_n0_block_cluster_idx =
cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
make_multi_index(get_block_1d_id()));
// HACK: this forces index data into SGPR // HACK: this forces index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment // TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM // A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM // B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() && a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize() && b_k0_n_k1_block_desc.GetElementSpaceSize() &&
"wrong!"); "wrong!");
...@@ -326,14 +316,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -326,14 +316,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>, Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_k0_m0_m1_k1_grid_desc), remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
decltype(a_k0_m0_m1_k1_block_desc), decltype(a_block_desc_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
...@@ -341,23 +331,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -341,23 +331,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false, false,
true>(a_k0_m0_m1_k1_grid_desc, true>(a_grid_desc_k0_m0_m1_k1,
make_multi_index(0, im0, 0, 0), make_multi_index(0, im0, 0, 0),
a_k0_m0_m1_k1_block_desc, a_block_desc_k0_m0_m1_k1,
make_multi_index(0, 0, 0, 0)); make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>, Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(b_k0_n0_n1_k1_grid_desc), remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
decltype(b_k0_n0_n1_k1_block_desc), decltype(b_block_desc_k0_n0_n1_k1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
...@@ -365,19 +355,19 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -365,19 +355,19 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false, false,
true>(b_k0_n0_n1_k1_grid_desc, true>(b_grid_desc_k0_n0_n1_k1,
make_multi_index(0, in0, 0, 0), make_multi_index(0, in0, 0, 0),
b_k0_n0_n1_k1_block_desc, b_block_desc_k0_n0_n1_k1,
make_multi_index(0, 0, 0, 0)); make_multi_index(0, 0, 0, 0));
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlockM1] is in LDS // a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlockN1] is in LDS // b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
...@@ -395,58 +385,53 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -395,58 +385,53 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple( constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple( constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block; FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc, // Initialize C
decltype(c_m10_m11_n10_n11_thread_desc), c_thread_buf.Clear();
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size, p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size, p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
} }
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
index_t k_block_data_begin = 0; index_t k_block_data_begin = 0;
...@@ -455,82 +440,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -455,82 +440,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
do do
{ {
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step, a_block_slice_copy_step);
AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step);
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
a_block_even_buf, a_block_even_buf,
b_block_even_buf, b_block_even_buf,
c_thread_buf); c_thread_buf);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step, a_block_slice_copy_step);
AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step);
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock; k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * KPerBlock); } while(k_block_data_begin < K0 - 2 * K0PerBlock);
} }
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{ {
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
__syncthreads(); block_sync_lds();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
__syncthreads(); block_sync_lds();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run( blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
} }
else // if 1 iteration is left else // if 1 iteration is left
{ {
...@@ -538,12 +517,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -538,12 +517,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run( blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
} }
// output: register to global memory // output: register to global memory
{ {
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed( make_naive_tensor_descriptor_packed(
make_tuple(I1, make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
...@@ -559,8 +538,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -559,8 +538,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
ThreadwiseTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
FloatC, FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0], c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1], c_m10_m11_n10_n11_thread_tensor_lengths[I1],
...@@ -572,22 +552,21 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -572,22 +552,21 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc, true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0, make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0, in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc, ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf, c_grid_buf);
CGridStepHacks{});
} }
} }
}; };
} // namespace ck } // namespace ck
#endif
...@@ -306,7 +306,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -306,7 +306,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n); c_grid_desc_m_n);
} }
......
...@@ -259,7 +259,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -259,7 +259,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n); c_grid_desc_m_n);
} }
......
#pragma once #pragma once
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -287,11 +288,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -287,11 +288,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
{ {
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n, M01, N01); c_grid_desc_m_n);
} }
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
......
...@@ -265,10 +265,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 ...@@ -265,10 +265,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
{ {
return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>( return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CMNGridDesc>(
c_m_n_grid_desc, M01, N01, KBatch); c_m_n_grid_desc, 8, KBatch);
} }
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
......
...@@ -239,10 +239,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -239,10 +239,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
{ {
return BlockToCTileMap_KSplit_M00_N00_M01_N01<MPerBlock, NPerBlock, CMNGridDesc>( return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CMNGridDesc>(
c_m_n_grid_desc, M01, N01, KBatch); c_m_n_grid_desc, 8, KBatch);
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
......
...@@ -300,11 +300,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 ...@@ -300,11 +300,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
{ {
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n, M01, N01); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype( remove_cvref_t<decltype(
......
...@@ -309,11 +309,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 ...@@ -309,11 +309,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
{ {
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n, M01, N01); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
......
...@@ -316,11 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 ...@@ -316,11 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
{ {
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n, M01, N01); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype( remove_cvref_t<decltype(
......
#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP #pragma once
#define CK_THREADWISE_CONTRACTION_DLOPS_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "math.hpp" #include "math.hpp"
...@@ -25,9 +23,9 @@ template <typename FloatA, ...@@ -25,9 +23,9 @@ template <typename FloatA,
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1
{ {
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
{ {
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
...@@ -124,9 +122,9 @@ template <typename FloatA, ...@@ -124,9 +122,9 @@ template <typename FloatA,
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{ {
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{ {
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
...@@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ ...@@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
}; };
} // namespace ck } // namespace ck
#endif
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP #pragma once
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1 ...@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1
}; };
} // namespace ck } // namespace ck
#endif
...@@ -28,6 +28,12 @@ __device__ float atomic_add<float>(float* p_dst, const float& x) ...@@ -28,6 +28,12 @@ __device__ float atomic_add<float>(float* p_dst, const float& x)
return atomicAdd(p_dst, x); return atomicAdd(p_dst, x);
} }
template <>
__device__ double atomic_add<double>(double* p_dst, const double& x)
{
return atomicAdd(p_dst, x);
}
template <> template <>
__device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x) __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
{ {
...@@ -45,6 +51,23 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x) ...@@ -45,6 +51,23 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
return vy.template AsType<float2_t>()[I0]; return vy.template AsType<float2_t>()[I0];
} }
template <>
__device__ double2_t atomic_add<double2_t>(double2_t* p_dst, const double2_t& x)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const vector_type<double, 2> vx{x};
vector_type<double, 2> vy{0};
vy.template AsType<double>()(I0) =
atomicAdd(c_style_pointer_cast<double*>(p_dst), vx.template AsType<double>()[I0]);
vy.template AsType<double>()(I1) =
atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, vx.template AsType<double>()[I1]);
return vy.template AsType<double2_t>()[I0];
}
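A hypothetical usage sketch of the new double specializations (the kernel name and launch shape are made up for the example, and the surrounding ck headers are assumed to be included): every contributing thread folds its element into one global accumulator through ck::atomic_add<double>.

// Hypothetical example kernel: reduce an array into a single double using the
// ck::atomic_add<double> specialization added above. Names here are illustrative only.
__global__ void accumulate_into_one(const double* __restrict__ p_in,
                                    int n,
                                    double* __restrict__ p_sum)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
    {
        ck::atomic_add<double>(p_sum, p_in[i]); // one atomic read-modify-write per element
    }
}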
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to // intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for // instantiate this template. The purpose is to make the implementation of atomic_max explicit for
......
#ifndef CK_INNER_PRODUCT_HPP #pragma once
#define CK_INNER_PRODUCT_HPP
#include "data_type.hpp" #include "data_type.hpp"
namespace ck { namespace ck {
...@@ -138,7 +136,7 @@ template <> ...@@ -138,7 +136,7 @@ template <>
__device__ void __device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c) inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{ {
#if defined(CK_USE_DOT4_I32_I8) #if defined(CK_USE_AMD_V_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
asm volatile("\n \ asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \
...@@ -202,4 +200,3 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t ...@@ -202,4 +200,3 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t
} }
} // namespace ck } // namespace ck
#endif
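For reference, the int8x4 path above (whether via the inline asm or a compiler builtin) computes a 4-lane signed 8-bit dot product accumulated into a 32-bit integer. A scalar host-side equivalent, handy for verifying results (the function name is illustrative):
#include <cstdint>

// Scalar reference for the packed int8x4 dot product: c += sum(a[i] * b[i]) over the
// four int8 lanes packed little-endian into each 32-bit word.
inline void dot4_i32_i8_ref(std::uint32_t a, std::uint32_t b, std::int32_t& c)
{
    for(int i = 0; i < 4; ++i)
    {
        const auto a_lane = static_cast<std::int8_t>((a >> (8 * i)) & 0xFF);
        const auto b_lane = static_cast<std::int8_t>((b >> (8 * i)) & 0xFF);
        c += static_cast<std::int32_t>(a_lane) * static_cast<std::int32_t>(b_lane);
    }
}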
...@@ -26,7 +26,8 @@ ...@@ -26,7 +26,8 @@
#ifndef CK_REDUCTION_OPERATOR_HPP #ifndef CK_REDUCTION_OPERATOR_HPP
#define CK_REDUCTION_OPERATOR_HPP #define CK_REDUCTION_OPERATOR_HPP
#include "common_header.hpp" #include "config.hpp"
#include "data_type.hpp"
namespace ck { namespace ck {
...@@ -41,12 +42,10 @@ namespace reduce { ...@@ -41,12 +42,10 @@ namespace reduce {
// when operated against them, and the concept is similar to zero vector in // when operated against them, and the concept is similar to zero vector in
// vector space // vector space
// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf). // (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) indexable -- boolean value indicating whether indices of the operated elements could be // 2) IsCompatibleInMemoryDataOperation() -- return true if the reduction task corresponding to this
// recorded. Usually, Min/Max operator could // operator can use the InMemoryDataOperation to finalize, otherwise it returns false. 3) operator() --
// need to record the indices of elements. For operator like Add/Mul, no need to // the first argument of the operator must be both an input & output, and the corresponding variable
// record the indices. // usually stores
// 3) operator() -- the first argument of the operator must be both an input & output, and the
// corresponding variable usually stores
// the accumulated result of many operator() calls; the second argument is only an // the accumulated result of many operator() calls; the second argument is only an
// input. For indexable binary // input. For indexable binary
// operator, the second version of operator() has third argument (which is an // operator, the second version of operator() has third argument (which is an
...@@ -62,6 +61,13 @@ struct Add ...@@ -62,6 +61,13 @@ struct Add
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }; __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::AtomicAdd ||
operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
}; };
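Add above now shows all three pieces of the interface: a neutral element, a compatibility check for the finalizing in-memory data operation, and an accumulate-in-place operator(). A hedged sketch of how calling code can fold values through such an operator (fold_reduce is illustrative and not part of this header):
// Illustrative only: seed the accumulator with the neutral element, then fold every
// value through operator(); the first argument is read-modify-write, the second is input only.
template <typename ReduceOp, typename T, index_t N>
__host__ __device__ T fold_reduce(const T (&values)[N])
{
    ReduceOp op{};
    T acc = ReduceOp::GetReductionZeroVal();
    for(index_t i = 0; i < N; ++i)
        op(acc, values[i]);
    return acc;
}
// e.g. fold_reduce<reduce::Add<float>>(v) sums v; fold_reduce<reduce::Max<float>>(v) takes its max.
// A multi-block kernel would additionally consult
// ReduceOp::IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum::AtomicAdd)
// before finalizing partial results with an atomic store.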
...@@ -72,6 +78,12 @@ struct Mul ...@@ -72,6 +78,12 @@ struct Mul
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); }; __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
}; };
...@@ -85,6 +97,13 @@ struct Max ...@@ -85,6 +97,13 @@ struct Max
return NumericLimits<T>::Lowest(); return NumericLimits<T>::Lowest();
}; };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const __host__ __device__ inline constexpr void operator()(T& a, T b) const
{ {
if(a < b) if(a < b)
...@@ -111,6 +130,13 @@ struct Min ...@@ -111,6 +130,13 @@ struct Min
return NumericLimits<T>::Max(); return NumericLimits<T>::Max();
}; };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_min to be added
return operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const __host__ __device__ inline constexpr void operator()(T& a, T b) const
{ {
if(a > b) if(a > b)
...@@ -134,6 +160,13 @@ struct AMax ...@@ -134,6 +160,13 @@ struct AMax
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }; __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
__device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
// ToChange: atomic_max to be added
return operation == InMemoryDataOperationEnum::Set;
};
__host__ __device__ inline constexpr void operator()(T& a, T b) const __host__ __device__ inline constexpr void operator()(T& a, T b) const
{ {
if(a < b) if(a < b)
...@@ -150,6 +183,17 @@ struct AMax ...@@ -150,6 +183,17 @@ struct AMax
} }
}; };
template <typename T>
T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
T result = ck::type_convert<T>(0.0f);
if(operation == InMemoryDataOperationEnum::AtomicMax)
result = ck::NumericLimits<T>::Lowest();
return (result);
};
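This helper presumably seeds the destination buffer before a reduction that finalizes atomically: zero for AtomicAdd, and the lowest representable value for AtomicMax so that the first atomic update always wins. A host-side usage sketch (the function below is illustrative):
// Illustrative: pre-fill an output buffer with the identity value matching the
// finalizing in-memory data operation chosen for the reduction.
template <typename T>
void init_output_for_reduction(T* p_out, index_t len, InMemoryDataOperationEnum operation)
{
    const T init_value = GetReductionZeroValueForInMemoryDataOperation<T>(operation);
    for(index_t i = 0; i < len; ++i)
        p_out[i] = init_value;
}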
}; // end of namespace reduce }; // end of namespace reduce
} // end of namespace ck } // end of namespace ck
......
...@@ -36,6 +36,11 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N> ...@@ -36,6 +36,11 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
{ {
return base::operator()(i); return base::operator()(i);
} }
__host__ __device__ void Clear()
{
static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; });
}
}; };
// static buffer for vector // static buffer for vector
...@@ -146,9 +151,9 @@ struct StaticBufferTupleOfVector ...@@ -146,9 +151,9 @@ struct StaticBufferTupleOfVector
__host__ __device__ void Clear() __host__ __device__ void Clear()
{ {
const index_t numScalars = NumOfVector * ScalarPerVector; constexpr index_t NumScalars = NumOfVector * ScalarPerVector;
static_for<0, Number<numScalars>{}, 1>{}([&](auto i) { SetAsType(i, S{0}); }); static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); });
} }
}; };
......
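Both Clear() helpers above rely on static_for, which unrolls at compile time by invoking the callable with Number<0>, Number<1>, and so on, so every index is a compile-time constant. A small sketch of that behavior, assuming CK's static_for semantics (the function name is illustrative):
// Illustrative: static_for<Begin, End, Increment> requires constant bounds; it expands
// to f(Number<0>{}), f(Number<1>{}), ..., f(Number<N-1>{}) at compile time.
inline void clear_example()
{
    constexpr index_t N = 4;
    float data[N];
    static_for<0, N, 1>{}([&](auto i) { data[i.value] = 0.f; }); // fully unrolled; i is a Number<> constant
    (void)data;
}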
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_COMMON_UTIL_HPP
#define GUARD_HOST_COMMON_UTIL_HPP
#include <vector>
#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
#include "config.hpp"
namespace ck {
namespace host_common {
template <typename T>
static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems)
{
std::ofstream outFile(fileName, std::ios::binary);
if(outFile)
{
outFile.write(reinterpret_cast<char*>(data), dataNumItems * sizeof(T));
outFile.close();
std::cout << "Write output to file " << fileName << std::endl;
}
else
{
std::cout << "Could not open file " << fileName << " for writing" << std::endl;
}
};
template <typename T>
static inline T getSingleValueFromString(const std::string& valueStr)
{
std::istringstream iss(valueStr);
T val;
iss >> val;
return (val);
};
template <typename T>
static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
{
std::string valuesStr(cstr_values);
std::vector<T> values;
std::size_t pos = 0;
std::size_t new_pos;
new_pos = valuesStr.find(',', pos);
while(new_pos != std::string::npos)
{
const std::string sliceStr = valuesStr.substr(pos, new_pos - pos);
T val = getSingleValueFromString<T>(sliceStr);
values.push_back(val);
pos = new_pos + 1;
new_pos = valuesStr.find(',', pos);
};
std::string sliceStr = valuesStr.substr(pos);
T val = getSingleValueFromString<T>(sliceStr);
values.push_back(val);
return (values);
}
}; // namespace host_common
}; // namespace ck
#endif
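A hedged usage sketch of the helpers above: getTypeValuesFromString splits a comma-separated list and parses each slice with getSingleValueFromString, and dumpBufferToFile writes the raw bytes of a host buffer. The header name and output file name below are assumptions:
#include <vector>
#include "host_common_util.hpp" // assumed file name for the header above

int main()
{
    // "64,32,16" -> {64, 32, 16}
    const std::vector<int> lengths = ck::host_common::getTypeValuesFromString<int>("64,32,16");

    // write lengths[0] floats (all zero here) as raw bytes to "out.bin"
    std::vector<float> out(lengths[0], 0.f);
    ck::host_common::dumpBufferToFile("out.bin", out.data(), out.size());

    return 0;
}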