Commit 07a673c6 authored by carlushuang

Merge remote-tracking branch 'origin/develop' into cpu_avx2

parents c0f698d5 ac0d8066
@@ -8,6 +8,7 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
+#include "tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
namespace tensor_operation {
@@ -24,7 +25,7 @@ template <typename ALayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -84,8 +85,8 @@ struct DeviceGemm_Xdl_CShuffle
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding ||
-             GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+             GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
@@ -108,8 +109,8 @@ struct DeviceGemm_Xdl_CShuffle
return a_grid_desc_ak0_m_ak1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-                  GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                  GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
@@ -125,8 +126,8 @@ struct DeviceGemm_Xdl_CShuffle
return a_grid_desc_ak0_m_ak1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-                  GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+                  GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
@@ -187,8 +188,8 @@ struct DeviceGemm_Xdl_CShuffle
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding ||
-             GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
+             GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
@@ -211,8 +212,8 @@ struct DeviceGemm_Xdl_CShuffle
return b_grid_desc_bk0_n_bk1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-                  GemmSpecialization == GemmSpecialization_t::MNPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                  GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
@@ -228,8 +229,8 @@ struct DeviceGemm_Xdl_CShuffle
return b_grid_desc_bk0_n_bk1;
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding ||
-                  GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
+                  GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
@@ -290,8 +291,8 @@ struct DeviceGemm_Xdl_CShuffle
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding ||
-             GemmSpecialization == GemmSpecialization_t::MNKPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
+             GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
@@ -300,8 +301,8 @@ struct DeviceGemm_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding ||
-                  GemmSpecialization == GemmSpecialization_t::MKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                  GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
@@ -310,8 +311,8 @@ struct DeviceGemm_Xdl_CShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
-else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding ||
-                  GemmSpecialization == GemmSpecialization_t::NKPadding)
+else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                  GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
@@ -340,7 +341,7 @@ struct DeviceGemm_Xdl_CShuffle
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
@@ -434,7 +435,7 @@ struct DeviceGemm_Xdl_CShuffle
{
using Argument = DeviceOp::Argument;
-float Run(const Argument& arg, int /* nrepeat */ = 1)
+float Run(const Argument& arg, int nrepeat = 1)
{
#if 0
{
@@ -465,6 +466,8 @@ struct DeviceGemm_Xdl_CShuffle
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+float ave_time = 0;
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
@@ -480,6 +483,8 @@ struct DeviceGemm_Xdl_CShuffle
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
+if(nrepeat == 0)
+{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
@@ -496,6 +501,26 @@ struct DeviceGemm_Xdl_CShuffle
arg.block_2_ctile_map_);
}
else
+{
+ave_time =
+    launch_and_time_kernel(kernel,
+                           nrepeat,
+                           dim3(grid_size),
+                           dim3(BlockSize),
+                           0,
+                           arg.p_a_grid_,
+                           arg.p_b_grid_,
+                           arg.p_c_grid_,
+                           arg.a_element_op_,
+                           arg.b_element_op_,
+                           arg.c_element_op_,
+                           arg.a_grid_desc_ak0_m_ak1_,
+                           arg.b_grid_desc_bk0_n_bk1_,
+                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                           arg.block_2_ctile_map_);
+}
+}
+else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
GridwiseGemm,
@@ -510,6 +535,8 @@ struct DeviceGemm_Xdl_CShuffle
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
+if(nrepeat == 0)
+{
launch_kernel(kernel,
dim3(grid_size),
dim3(BlockSize),
@@ -525,8 +552,28 @@ struct DeviceGemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
}
+else
+{
+ave_time =
+    launch_and_time_kernel(kernel,
+                           nrepeat,
+                           dim3(grid_size),
+                           dim3(BlockSize),
+                           0,
+                           arg.p_a_grid_,
+                           arg.p_b_grid_,
+                           arg.p_c_grid_,
+                           arg.a_element_op_,
+                           arg.b_element_op_,
+                           arg.c_element_op_,
+                           arg.a_grid_desc_ak0_m_ak1_,
+                           arg.b_grid_desc_bk0_n_bk1_,
+                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                           arg.block_2_ctile_map_);
+}
+}
-return 0;
+return ave_time;
}
// polymorphic
......
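For context on the Run() change in this file: nrepeat used to be ignored (the kernel was always launched once and Run() returned 0); it now selects between a plain launch and a timed launch whose average is returned. A minimal sketch of that dispatch, reusing only the launcher names and argument order visible in the hunks above (everything else here is illustrative):

// illustrative sketch, not part of the commit
template <typename Kernel, typename... Args>
float run_with_optional_timing(Kernel kernel, int nrepeat, dim3 grid, dim3 block, Args... args)
{
    float ave_time = 0;
    if(nrepeat == 0)
    {
        // fire-and-forget launch, no timing
        launch_kernel(kernel, grid, block, 0, args...);
    }
    else
    {
        // launch nrepeat times and report the averaged kernel time
        ave_time = launch_and_time_kernel(kernel, nrepeat, grid, block, 0, args...);
    }
    return ave_time;
}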
@@ -31,7 +31,7 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -91,7 +91,7 @@ struct DeviceGemmXdlSplitK
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
@@ -136,7 +136,7 @@ struct DeviceGemmXdlSplitK
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
@@ -170,7 +170,7 @@ struct DeviceGemmXdlSplitK
}
}();
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -209,7 +209,7 @@ struct DeviceGemmXdlSplitK
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
@@ -250,7 +250,7 @@ struct DeviceGemmXdlSplitK
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
...
@@ -31,7 +31,7 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -93,7 +93,7 @@ struct DeviceGemmXdlSplitKCShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return transform_tensor_descriptor(
@@ -138,7 +138,7 @@ struct DeviceGemmXdlSplitKCShuffle
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
@@ -172,7 +172,7 @@ struct DeviceGemmXdlSplitKCShuffle
}
}();
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -211,7 +211,7 @@ struct DeviceGemmXdlSplitKCShuffle
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
@@ -253,7 +253,7 @@ struct DeviceGemmXdlSplitKCShuffle
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
...
@@ -27,7 +27,7 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
-GemmSpecialization_t GemmSpecialization,
+GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -81,7 +81,7 @@ struct DeviceGroupedGemmXdl
}
}();
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
@@ -120,7 +120,7 @@ struct DeviceGroupedGemmXdl
}
}();
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -155,7 +155,7 @@ struct DeviceGroupedGemmXdl
}
}();
-if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding)
+if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
@@ -187,7 +187,7 @@ struct DeviceGroupedGemmXdl
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
...
@@ -10,7 +10,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
-template <ck::ReduceTensorOp_t ReduceOpId>
+template <ck::ReduceTensorOp ReduceOpId>
struct DevicePool2dFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
@@ -29,7 +29,7 @@ struct DevicePool2dFwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
-template <ck::ReduceTensorOp_t ReduceOpId>
+template <ck::ReduceTensorOp ReduceOpId>
using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;
} // namespace device
...
@@ -16,7 +16,7 @@ namespace device {
template <typename InDataType,
typename OutDataType,
typename AccDataType,
-ck::ReduceTensorOp_t ReduceOpId,
+ck::ReduceTensorOp ReduceOpId,
bool NeedIndices,
ck::index_t BlockSize,
ck::index_t ReduceMThreadClusterSize,
@@ -181,7 +181,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
reduce_lowest_length_ = window_spatial_lengths[1];
// TODO: is this correct?
-if constexpr(ReduceOpId == ck::ReduceTensorOp_t::AVG)
+if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
{
ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
in_element_op_ = InElementwiseOperation{divider};
...
@@ -5,7 +5,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
-enum struct GemmSpecialization_t
+enum struct GemmSpecialization
{
Default,
MPadding,
......
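The only change in this header is dropping the `_t` suffix from the enum name. A small sketch of how a padding flag derived from it is typically checked (the helper below is hypothetical, but the enumerators are the ones listed above and used in the `if constexpr` chains earlier in this diff):

#include "gemm_specialization.hpp"

namespace ck::tensor_operation::device {
// hypothetical helper: does this specialization require padding the M dimension?
template <GemmSpecialization GemmSpec>
constexpr bool NeedsMPadding()
{
    return GemmSpec == GemmSpecialization::MPadding ||
           GemmSpec == GemmSpecialization::MNPadding ||
           GemmSpec == GemmSpecialization::MKPadding ||
           GemmSpec == GemmSpecialization::MNKPadding;
}
} // namespace ck::tensor_operation::device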
@@ -37,11 +37,11 @@ namespace ck {
// The boolean member "indexable" are also provided in reduce_binary_operactor for
// easier checking by the upper-layer codes in the kernels.
-template <typename T, ReduceTensorOp_t Op>
+template <typename T, ReduceTensorOp Op>
struct reduce_binary_operator;
template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
+struct reduce_binary_operator<T, ReduceTensorOp::ADD>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -50,7 +50,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
};
template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
+struct reduce_binary_operator<T, ReduceTensorOp::MUL>
{
using opType = reduce::Mul<T>;
using dataType = T;
@@ -59,7 +59,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
};
template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
+struct reduce_binary_operator<T, ReduceTensorOp::MIN>
{
using opType = reduce::Min<T>;
using dataType = T;
@@ -68,7 +68,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
};
template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
+struct reduce_binary_operator<T, ReduceTensorOp::MAX>
{
using opType = reduce::Max<T>;
using dataType = T;
@@ -77,7 +77,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
};
template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
+struct reduce_binary_operator<T, ReduceTensorOp::AMAX>
{
using opType = reduce::AMax<T>;
using dataType = T;
@@ -86,7 +86,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
};
template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
+struct reduce_binary_operator<T, ReduceTensorOp::AVG>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -95,7 +95,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
};
template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
+struct reduce_binary_operator<T, ReduceTensorOp::NORM1>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -104,7 +104,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
};
template <typename T>
-struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
+struct reduce_binary_operator<T, ReduceTensorOp::NORM2>
{
using opType = reduce::Add<T>;
using dataType = T;
@@ -115,7 +115,7 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary
// functor classes.
// The two unary functors are called before and afer the Reduction is executed respectively
-template <typename T, ReduceTensorOp_t Op, bool IsFirstReduce, bool IsLastReduce>
+template <typename T, ReduceTensorOp Op, bool IsFirstReduce, bool IsLastReduce>
struct reduce_unary_operator
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
@@ -123,42 +123,42 @@ struct reduce_unary_operator
};
template <typename T, bool IsFirstReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::AVG, IsFirstReduce, true>
+struct reduce_unary_operator<T, ReduceTensorOp::AVG, IsFirstReduce, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T, true>;
};
template <typename T, bool IsLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM1, true, IsLastReduce>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM1, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};
template <typename T, bool IsLastReduce>
-struct reduce_unary_operator<T, ReduceTensorOp_t::AMAX, true, IsLastReduce>
+struct reduce_unary_operator<T, ReduceTensorOp::AMAX, true, IsLastReduce>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};
template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, false>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, false>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
};
template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, true, true>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM2, true, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnarySquare<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
};
template <typename T>
-struct reduce_unary_operator<T, ReduceTensorOp_t::NORM2, false, true>
+struct reduce_unary_operator<T, ReduceTensorOp::NORM2, false, true>
{
using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic<T, T>;
using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt<T, T>;
......
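These traits only changed in that the enum they key on lost its `_t` suffix. For orientation, they are consumed roughly like this (an illustrative sketch; the alias names on the left are mine, while the trait members opType / InElementwiseOperation / AccElementwiseOperation are the ones defined above):

// Accumulation functor for an ADD reduction over float:
using AddFunctor = typename ck::reduce_binary_operator<float, ck::ReduceTensorOp::ADD>::opType;

// Pre/post element-wise functors for a single-pass NORM2 reduction
// (square the inputs going in, take the square root of the accumulated value coming out):
using Norm2Pre  = typename ck::reduce_unary_operator<float, ck::ReduceTensorOp::NORM2, true, true>::InElementwiseOperation;
using Norm2Post = typename ck::reduce_unary_operator<float, ck::ReduceTensorOp::NORM2, true, true>::AccElementwiseOperation;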
@@ -31,6 +31,7 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
+#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
@@ -179,10 +180,10 @@ struct GridwiseReduction_mk_to_m_blockwise
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
-// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
-// Dim_K as the fastest one
-static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
-make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
+using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+using ThreadReduceDstDesc_M =
+decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
@@ -216,32 +217,33 @@ struct GridwiseReduction_mk_to_m_blockwise
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
-using Accumulation =
-detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+ThreadReduceSrcDesc_M_K,
+ThreadReduceDstDesc_M,
+ReduceOperation,
+PropagateNan>;
(void)p_ws_indices_global;
(void)p_indices_global;
// LDS
-__shared__ AccDataType p_block_reduce_buffer[BlockSize];
+__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
-const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
-auto block_reduce_buf =
-make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
+auto reduce_work_buf =
+make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
-AccDataType,
-MThreadSliceSize * KThreadSliceSize,
-true>
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
-StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
@@ -288,38 +290,26 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0),
in_thread_buf);
-static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
-static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-});
-// reduce on each thread-local slice
-static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+in_elementwise_op(in_thread_buf(Number<offset>{}),
+in_thread_buf(Number<offset>{}));
});
});
+ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
-constexpr auto reduced_data_desc =
-make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
-static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-accu_value_buf[I];
-accu_value_buf(I) = zeroVal;
-__syncthreads();
-BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
-});
+constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
+static_for<0, MThreadSliceSize, 1>{}(
+[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
@@ -336,7 +326,7 @@ struct GridwiseReduction_mk_to_m_blockwise
{
if(!float_equal_zero{}(beta))
{
-StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
@@ -376,7 +366,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
@@ -417,35 +407,34 @@ struct GridwiseReduction_mk_to_m_blockwise
(void)p_ws_indices_global;
// LDS
-__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
-__shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
+__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
+__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
-const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
-auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
-auto block_reduce_val_buf =
-make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
-auto block_reduce_idx_buf =
-make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
+auto reduce_work_val_buf =
+make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
+auto reduce_work_idx_buf =
+make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
-AccDataType,
-MThreadSliceSize * KThreadSliceSize,
-true>
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
-StaticBuffer<AddressSpaceEnum_t::Vgpr, index_t, MThreadSliceSize * KThreadSliceSize, true>
+StaticBuffer<AddressSpaceEnum::Vgpr,
+IndexDataType,
+MThreadSliceSize * KThreadSliceSize,
+true>
in_thread_idx_buf;
-StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
-accu_index_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
@@ -498,42 +487,36 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0),
in_thread_val_buf);
-static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
+static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
-in_thread_idx_buf(offset) =
-indexOffset + thread_k_cluster_id * KThreadSliceSize + J();
+in_thread_idx_buf(Number<offset>{}) =
+indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
-in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset));
+in_elementwise_op(in_thread_val_buf(Number<offset>{}),
+in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
-static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-// reduce on the dim1 thread slice
-AccumulationWithIndex::Calculate(
-tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
+static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+AccumulationWithIndex::Calculate(tmpValue,
+in_thread_val_buf[Number<offset>{}],
+tmpIndex,
+in_thread_idx_buf[Number<offset>{}]);
});
-// store thread local value to LDS for parallel reduction
-block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-tmpValue;
-block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-tmpIndex;
-__syncthreads();
BlockwiseReduceWithIndex::Reduce(
-block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
+reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
-accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
+accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
@@ -542,8 +525,7 @@ struct GridwiseReduction_mk_to_m_blockwise
reducedTiles++;
} while(reducedTiles < toReduceTiles);
-constexpr auto reduced_data_desc =
-make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
@@ -561,7 +543,7 @@ struct GridwiseReduction_mk_to_m_blockwise
{
if(!float_equal_zero{}(beta))
{
-StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
@@ -601,7 +583,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
@@ -619,7 +601,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
@@ -672,42 +654,38 @@ struct GridwiseReduction_mk_to_m_blockwise
(void)in_elementwise_op;
// LDS
-__shared__ AccDataType p_block_reduce_val_buffer[BlockSize];
-__shared__ IndexDataType p_block_reduce_idx_buffer[BlockSize];
+__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
+__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto src_global_val_buf =
-make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_ws_values_global,
+make_dynamic_buffer<AddressSpaceEnum::Global>(p_ws_values_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
-const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
-auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
-auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
-auto block_reduce_val_buf =
-make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize);
-auto block_reduce_idx_buf =
-make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize);
+auto reduce_work_val_buf =
+make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
+auto reduce_work_idx_buf =
+make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
-AccDataType,
-MThreadSliceSize * KThreadSliceSize,
-true>
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
+StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
-StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
-StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true>
-accu_index_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
@@ -756,8 +734,6 @@ struct GridwiseReduction_mk_to_m_blockwise
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
-// index_t indexOffset = 0;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
@@ -782,42 +758,33 @@ struct GridwiseReduction_mk_to_m_blockwise
make_tuple(I0, I0),
in_thread_idx_buf);
-static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
-static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-// reduce on the dim1 thread slice
-AccumulationWithIndex::Calculate(
-tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]);
+static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+AccumulationWithIndex::Calculate(tmpValue,
+in_thread_val_buf[Number<offset>{}],
+tmpIndex,
+in_thread_idx_buf[Number<offset>{}]);
});
-// store thread local value to LDS for parallel reduction
-block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-tmpValue;
-block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-tmpIndex;
-__syncthreads();
BlockwiseReduceWithIndex::Reduce(
-block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex);
+reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
-accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex);
+accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
-// indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < toReduceTiles);
-constexpr auto reduced_data_desc =
-make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
@@ -835,7 +802,7 @@ struct GridwiseReduction_mk_to_m_blockwise
{
if(!float_equal_zero{}(beta))
{
-StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true>
+StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
@@ -875,7 +842,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
@@ -893,7 +860,7 @@ struct GridwiseReduction_mk_to_m_blockwise
Sequence<0>,
0,
OutDstVectorSize,
-InMemoryDataOperationEnum_t::Set,
+InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
......
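The structural change in this file is that the hand-rolled per-thread K-slice loop (and its staging of partial values through the old block_buf_desc_m_k LDS layout) is replaced by a ThreadwiseReduction helper followed by the existing blockwise reduce. A standalone sketch of the per-thread step that ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf) now encapsulates (plain arrays and ADD accumulation here; the real helper works on StaticBuffer with compile-time offsets, a pluggable ReduceOperation, and a NaN-propagation policy):

// illustrative sketch only, not the contents of reduction_functions_threadwise.hpp
template <int MThreadSliceSize, int KThreadSliceSize>
void thread_reduce(const float (&in_thread_buf)[MThreadSliceSize * KThreadSliceSize],
                   float (&accu_value_buf)[MThreadSliceSize])
{
    for(int iM = 0; iM < MThreadSliceSize; ++iM)        // invariant (kept) dimension
    {
        for(int iK = 0; iK < KThreadSliceSize; ++iK)    // reduced dimension
        {
            const int offset = iM * KThreadSliceSize + iK; // packed (M, K) layout, K fastest
            accu_value_buf[iM] += in_thread_buf[offset];   // ReduceOperation would plug in here
        }
    }
}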
@@ -30,6 +30,7 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
+#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
@@ -103,10 +104,10 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
-// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
-// Dim_K as the fastest one
-static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
-make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
+using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+using ThreadReduceDstDesc_M =
+decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
@@ -115,6 +116,12 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
ReduceOperation,
PropagateNan>;
+using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+ThreadReduceSrcDesc_M_K,
+ThreadReduceDstDesc_M,
+ReduceOperation,
+PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
@@ -138,23 +145,20 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
-__shared__ AccDataType p_block_reduce_buffer[BlockSize];
+__shared__ AccDataType p_reduce_work_buffer[BlockSize];
-const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
-auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
-auto block_reduce_buf =
-make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
+auto reduce_work_buf =
+make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
-AccDataType,
-MThreadSliceSize * KThreadSliceSize,
-true>
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
-StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
@@ -201,42 +205,30 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
make_tuple(I0, I0),
in_thread_buf);
-static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
-static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
-});
-// reduce on each thread-local slice
-static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
+static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+in_elementwise_op(in_thread_buf(Number<offset>{}),
+in_thread_buf(Number<offset>{}));
});
});
+ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
-constexpr auto reduced_data_desc =
-make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// Each block executes multiple parallel reductions on the LDS, and by atomic-adding its
// reduced output to the global location corresponding to each invariant dimension to get a
// consistent reduced result for that invariant dimension. due to the using of vector_load,
// each block/thread is involved into multiple invarirant dimensions.
-static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
-accu_value_buf[I];
-accu_value_buf(I) = zeroVal;
-__syncthreads();
-BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
-});
+static_for<0, MThreadSliceSize, 1>{}(
+[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
@@ -259,7 +251,7 @@ struct GridwiseReduction_mk_to_m_multiblock_atomic_add
Sequence<0>,
0,
OutDstVectorSize,
-InMemoryDataOperationEnum_t::AtomicAdd,
+InMemoryDataOperationEnum::AtomicAdd,
1,
true>(
out_grid_desc_m,
......
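This file applies the same ThreadwiseReduction refactor and then writes each block's partial result with InMemoryDataOperationEnum::AtomicAdd, as the comment in the hunk above describes. A conceptual HIP/CUDA sketch of that multiblock strategy, greatly simplified (one block per (output row, K tile) pair, output assumed zero-initialized; CK instead routes the atomic add through its threadwise tensor-slice transfer):

// illustrative sketch, not CK code
__global__ void partial_reduce_atomic_add(const float* in, float* out, int K, int KPerBlock)
{
    const int m       = blockIdx.x;                 // invariant (output) index handled by this block
    const int k_begin = blockIdx.y * KPerBlock;     // this block's slice of the reduced dimension
    const int k_end   = (k_begin + KPerBlock < K) ? (k_begin + KPerBlock) : K;

    float partial = 0.f;
    for(int k = k_begin + threadIdx.x; k < k_end; k += blockDim.x)
        partial += in[m * K + k];

    // a real kernel would first combine "partial" across the block (e.g. in LDS)
    // so that only one atomic per block reaches global memory
    atomicAdd(&out[m], partial);
}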
@@ -30,6 +30,7 @@
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
+#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
@@ -121,10 +122,10 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
-// For laying out the threads to do reducing on LDS buffer, for LDS buffer, we always use the
-// Dim_K as the fastest one
-static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
-make_tuple(Number<MThreadClusterSize>{}, Number<KThreadClusterSize>{}));
+using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
+make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+using ThreadReduceDstDesc_M =
+decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
@@ -151,8 +152,11 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
ReduceOperation,
PropagateNan>;
-using Accumulation =
-detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
+ThreadReduceSrcDesc_M_K,
+ThreadReduceDstDesc_M,
+ReduceOperation,
+PropagateNan>;
(void)p_ws_indices_global;
(void)acc_elementwise_op;
@@ -160,25 +164,22 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
-__shared__ AccDataType p_block_reduce_buffer[BlockSize];
+__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf =
-make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global,
+make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
-auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
-auto block_reduce_buf =
-make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_buffer, BlockSize);
+auto reduce_work_buf =
+make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
-StaticBuffer<AddressSpaceEnum_t::Vgpr,
-AccDataType,
-MThreadSliceSize * KThreadSliceSize,
-true>
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
-StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
@@ -225,20 +226,17 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_tuple(I0, I0),
in_thread_buf);
-static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
-static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
-constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
-in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset));
+static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
+constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+in_elementwise_op(in_thread_buf(Number<offset>{}),
}); in_thread_buf(Number<offset>{}));
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
}); });
}); });
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++; reducedTiles++;
...@@ -246,16 +244,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -246,16 +244,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
// Each block executes multiple parallel reductions on the LDS, and due to the use of // Each block executes multiple parallel reductions on the LDS, and due to the use of
// vector_load, each block/thread is involved in multiple invariant dimensions. // vector_load, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}(
block_reduce_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
accu_value_buf[I];
accu_value_buf(I) = zeroVal;
__syncthreads();
BlockwiseReduce::Reduce(block_reduce_buf, accu_value_buf(I));
});
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed( constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})); make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
...@@ -272,7 +262,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -272,7 +262,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
1, 1,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
workspace_desc_m_k, workspace_desc_m_k,
...@@ -318,37 +308,33 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -318,37 +308,33 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_block_reduce_val_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ index_t p_block_reduce_idx_buffer[BlockSize]; __shared__ index_t p_reduce_work_idx_buffer[BlockSize];
const auto in_global_buf = const auto in_global_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Global>(p_src_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(zeroVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize());
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize()); p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize());
auto block_reduce_val_buf = auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_val_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto block_reduce_idx_buf = auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum_t::Lds>(p_block_reduce_idx_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_val_buf; in_thread_val_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType, IndexDataType,
MThreadSliceSize * KThreadSliceSize, MThreadSliceSize * KThreadSliceSize,
true> true>
in_thread_idx_buf; in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
accu_index_buf;
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -401,42 +387,36 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -401,42 +387,36 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_val_buf); in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices of the per-thread values to be reduced // initialize the indices of the per-thread values to be reduced
in_thread_idx_buf(offset) = in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + J(); indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(offset), in_thread_val_buf(offset)); in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
}); });
AccDataType tmpValue = zeroVal; AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0; IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// reduce on the dim1 thread slice AccumulationWithIndex::Calculate(tmpValue,
AccumulationWithIndex::Calculate( in_thread_val_buf[Number<offset>{}],
tmpValue, in_thread_val_buf[offset], tmpIndex, in_thread_idx_buf[offset]); tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
}); });
// store thread local value to LDS for parallel reduction
block_reduce_val_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpValue;
block_reduce_idx_buf(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) =
tmpIndex;
__syncthreads();
BlockwiseReduceWithIndex::Reduce( BlockwiseReduceWithIndex::Reduce(
block_reduce_val_buf, block_reduce_idx_buf, tmpValue, tmpIndex); reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate( AccumulationWithIndex::Calculate(
accu_value_buf(I), tmpValue, accu_index_buf(I), tmpIndex); accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
}); });
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
...@@ -461,7 +441,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -461,7 +441,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
1, 1,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
workspace_desc_m_k, workspace_desc_m_k,
...@@ -480,7 +460,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -480,7 +460,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
1, 1,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
workspace_desc_m_k, workspace_desc_m_k,
......
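Besides the threadwise change, the partial-reduce kernel drops its explicit LDS staging: previously every thread copied its partial value into block_reduce_buf, reset its accumulator, issued __syncthreads(), and only then called BlockwiseReduce::Reduce, whereas the renamed reduce_work_buf is now handed straight to BlockwiseReduce::Reduce (and BlockwiseReduceWithIndex::Reduce), which are assumed to own the staging and the barriers. A hedged CUDA-style sketch of that encapsulation, using a hypothetical block_reduce_sum helper rather than the CK types:

// Hedged sketch (not the CK BlockwiseReduce): the point is that the caller no
// longer stages values into shared memory or calls __syncthreads() itself --
// the helper owns the workspace handed to it and all barriers.
#include <cstdio>
#include <cuda_runtime.h>

template <int BlockSize>
__device__ float block_reduce_sum(float* work, float v)
{
    const int tid = threadIdx.x;
    work[tid] = v; // staging now lives inside the helper
    __syncthreads();
    for(int stride = BlockSize / 2; stride > 0; stride /= 2)
    {
        if(tid < stride)
            work[tid] += work[tid + stride];
        __syncthreads();
    }
    return work[0];
}

template <int BlockSize>
__global__ void reduce_kernel(const float* in, float* out)
{
    __shared__ float work[BlockSize]; // analogous to p_reduce_work_buffer
    const float partial = in[blockIdx.x * BlockSize + threadIdx.x];
    const float total   = block_reduce_sum<BlockSize>(work, partial);
    if(threadIdx.x == 0)
        out[blockIdx.x] = total;
}

int main()
{
    constexpr int BlockSize = 256, Blocks = 4, N = BlockSize * Blocks;
    float *in, *out;
    cudaMallocManaged(&in, N * sizeof(float));
    cudaMallocManaged(&out, Blocks * sizeof(float));
    for(int i = 0; i < N; ++i)
        in[i] = 1.0f;
    reduce_kernel<BlockSize><<<Blocks, BlockSize>>>(in, out);
    cudaDeviceSynchronize();
    std::printf("%f\n", out[0]); // prints 256
}

Keeping the barrier inside the helper removes the easy-to-miss __syncthreads() at every call site, which is the main readability win of the hunks above.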
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
...@@ -110,6 +111,11 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -110,6 +111,11 @@ struct GridwiseReduction_mk_to_m_threadwise
using ThreadBufferDimAccessOrder = using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type; typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough; using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -124,26 +130,25 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -124,26 +130,25 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
using Accumulation = ThreadReduceSrcDesc_M_K,
detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>; ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_indices_global; (void)p_indices_global;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal)); p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize()); p_out_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
AccDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_buf; in_thread_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
...@@ -178,20 +183,17 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -178,20 +183,17 @@ struct GridwiseReduction_mk_to_m_threadwise
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset)); in_elementwise_op(in_thread_buf(Number<offset>{}),
}); in_thread_buf(Number<offset>{}));
// reduce on each thread-local slice
static_for<0, KThreadSliceSize, 1>{}([&](auto J) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
Accumulation::Calculate(accu_value_buf(I), in_thread_buf[offset]);
}); });
}); });
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize; reducedLength += KThreadSliceSize;
...@@ -203,8 +205,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -203,8 +205,7 @@ struct GridwiseReduction_mk_to_m_threadwise
accu_value_buf(I) *= alpha; accu_value_buf(I) *= alpha;
}); });
constexpr auto reduced_data_desc = constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
if constexpr(!BetaIsZero) if constexpr(!BetaIsZero)
{ {
...@@ -223,7 +224,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -223,7 +224,7 @@ struct GridwiseReduction_mk_to_m_threadwise
true>( true>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf; priorDstValue_buf;
threadwise_dst_load.Run(out_grid_desc_m, threadwise_dst_load.Run(out_grid_desc_m,
...@@ -248,7 +249,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -248,7 +249,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>, Sequence<0>,
0, 0,
OutDstVectorSize, OutDstVectorSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>( false>(
out_grid_desc_m, out_grid_desc_m,
...@@ -269,30 +270,35 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -269,30 +270,35 @@ struct GridwiseReduction_mk_to_m_threadwise
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_indices_global)
{ {
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan, using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
IndexDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation, ReduceOperation,
AccDataType, PropagateNan>;
IndexDataType>;
(void)acc_elementwise_op; (void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal)); p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize()); p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize()); p_indices_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
AccDataType, in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize, MThreadSliceSize * KThreadSliceSize,
true> true>
in_thread_buf; in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum_t::Vgpr, IndexDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
accu_index_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal; accu_value_buf(I) = zeroVal;
...@@ -329,26 +335,23 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -329,26 +335,23 @@ struct GridwiseReduction_mk_to_m_threadwise
in_global_buf, in_global_buf,
thread_buffer_desc, thread_buffer_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = I * Number<KThreadSliceSize>{} + J; constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(offset), in_thread_buf(offset)); in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
});
// reduce on each thread-local slice in_elementwise_op(in_thread_val_buf(Number<offset>{}),
static_for<0, KThreadSliceSize, 1>{}([&](auto J) { in_thread_val_buf(Number<offset>{}));
constexpr auto offset = I * Number<KThreadSliceSize>{} + J;
AccumulationWithIndex::Calculate(accu_value_buf(I),
in_thread_buf[offset],
accu_index_buf(I),
indexStart + J);
}); });
}); });
ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexStart += KThreadSliceSize; indexStart += KThreadSliceSize;
...@@ -362,8 +365,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -362,8 +365,7 @@ struct GridwiseReduction_mk_to_m_threadwise
accu_value_buf(I) *= alpha; accu_value_buf(I) *= alpha;
}); });
constexpr auto reduced_data_desc = constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
if constexpr(!BetaIsZero) if constexpr(!BetaIsZero)
{ {
...@@ -382,7 +384,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -382,7 +384,7 @@ struct GridwiseReduction_mk_to_m_threadwise
false>( false>(
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
StaticBuffer<AddressSpaceEnum_t::Vgpr, OutDataType, MThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValue_buf; priorDstValue_buf;
threadwise_dst_load.Run(out_grid_desc_m, threadwise_dst_load.Run(out_grid_desc_m,
...@@ -407,7 +409,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -407,7 +409,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>, Sequence<0>,
0, 0,
OutDstVectorSize, OutDstVectorSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>( false>(
out_grid_desc_m, out_grid_desc_m,
...@@ -424,7 +426,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -424,7 +426,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>, Sequence<0>,
0, 0,
OutDstVectorSize, OutDstVectorSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
1, 1,
false>( false>(
out_grid_desc_m, out_grid_desc_m,
......
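For the indexed (argmax/argmin-style) path of the threadwise kernel, the old code interleaved index initialization with AccumulationWithIndex::Calculate inside the nested loops; the new code fills in_thread_idx_buf alongside the element-wise op and then reduces values and indices together through ThreadwiseReduceWithIndex::Reduce. A hedged sketch of such an index-carrying reduction, assuming a max operation (the real template is parameterized by ReduceOperation and PropagateNan):

// Hedged sketch of an index-tracking thread-local reduction (argmax flavour),
// mirroring the in_thread_val_buf / in_thread_idx_buf pairing above; the real
// ThreadwiseReductionWithIndex is descriptor-driven and operation-generic.
#include <array>
#include <cstdio>

template <int M, int K, typename T, typename I>
struct ThreadwiseReductionWithIndexSketch
{
    static void Reduce(const std::array<T, M * K>& val,
                       const std::array<I, M * K>& idx,
                       std::array<T, M>& acc_val,
                       std::array<I, M>& acc_idx)
    {
        for(int m = 0; m < M; ++m)
            for(int k = 0; k < K; ++k)
            {
                const int off = m * K + k;
                if(val[off] > acc_val[m]) // keep the value and the index it came from
                {
                    acc_val[m] = val[off];
                    acc_idx[m] = idx[off];
                }
            }
    }
};

int main()
{
    std::array<float, 1 * 4> v{3.f, 8.f, 2.f, 5.f};
    std::array<int, 1 * 4> i{10, 11, 12, 13}; // e.g. indexStart + iK
    std::array<float, 1> av{-1e30f};          // analogous to accu_value_buf
    std::array<int, 1> ai{0};                 // analogous to accu_index_buf
    ThreadwiseReductionWithIndexSketch<1, 4, float, int>::Reduce(v, i, av, ai);
    std::printf("%f at %d\n", av[0], ai[0]); // prints 8 at 11
}

Usage mirrors accu_value_buf / accu_index_buf above: the per-row winning value and the global index it came from survive the tile loop together.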
...@@ -55,7 +55,7 @@ template <index_t BlockSize, ...@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1, typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1, typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1, typename CGridDesc_GM0_GM1_GN0_GN1,
...@@ -329,11 +329,11 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -329,11 +329,11 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
...@@ -383,7 +383,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -383,7 +383,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>, Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
...@@ -407,7 +407,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -407,7 +407,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>, Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
...@@ -467,7 +467,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -467,7 +467,7 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
...@@ -481,15 +481,15 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN ...@@ -481,15 +481,15 @@ struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size, p_a_block_double + a_block_aligned_space_size,
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size, p_b_block_double + b_block_aligned_space_size,
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
......
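The remaining hunks in this commit are mechanical renames with no behavioral change: GemmSpecialization_t, InMemoryDataOperationEnum_t, AddressSpaceEnum_t and ActivTypeEnum_t become GemmSpecialization, InMemoryDataOperationEnum, AddressSpaceEnum and ActivTypeEnum at every use site. A hedged sketch of how such a rename can be staged without breaking out-of-tree callers, assuming the types are scoped enums (only the names, not the definitions, appear in this diff) and using a hypothetical deprecated alias:

// Hedged sketch: one way to rename an enum type while old spellings keep
// compiling. The alias below is hypothetical (not part of this commit) and
// would be deleted once all call sites, like the ones in this diff, migrate.
#include <cstdio>

enum class InMemoryDataOperationEnum // new, shorter name
{
    Set,
    AtomicAdd,
    Add
};

using InMemoryDataOperationEnum_t = InMemoryDataOperationEnum; // transitional alias

template <InMemoryDataOperationEnum Op>
void store(float* dst, float v)
{
    if constexpr(Op == InMemoryDataOperationEnum::Set)
        *dst = v;
    else
        *dst += v; // AtomicAdd / Add collapsed for brevity in this sketch
}

int main()
{
    float x = 1.0f;
    store<InMemoryDataOperationEnum::Add>(&x, 2.0f);
    std::printf("%f\n", x); // prints 3
}

Once every call site is migrated, as the diff does for the gridwise GEMM and contraction kernels below, the alias can be dropped.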
...@@ -55,7 +55,7 @@ template <index_t BlockSize, ...@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AKMGridDesc, typename AKMGridDesc,
typename BKNGridDesc, typename BKNGridDesc,
typename CMNGridDesc, typename CMNGridDesc,
...@@ -268,11 +268,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -268,11 +268,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
const auto K = a_k_m0_m1_grid_desc.GetLength(I0); const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
...@@ -315,7 +315,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -315,7 +315,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1>, Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1, ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1, ABlockTransferThreadClusterLengths_K_M0_M1,
...@@ -341,7 +341,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -341,7 +341,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1>, Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1, BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1, BBlockTransferThreadClusterLengths_K_N0_N1,
...@@ -403,7 +403,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -403,7 +403,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
...@@ -428,15 +428,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2 ...@@ -428,15 +428,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r2
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack = constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowStepHacks{}; BGridMoveSliceWindowStepHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size, p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize()); a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size, p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize()); b_k_n0_n1_block_desc.GetElementSpaceSize());
......
...@@ -55,7 +55,7 @@ template <index_t BlockSize, ...@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AK0MK1GridDesc, typename AK0MK1GridDesc,
typename BK0NK1GridDesc, typename BK0NK1GridDesc,
typename CMNGridDesc, typename CMNGridDesc,
...@@ -275,11 +275,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -275,11 +275,11 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
...@@ -325,7 +325,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -325,7 +325,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>, Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
...@@ -349,7 +349,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -349,7 +349,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>, Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
...@@ -409,7 +409,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -409,7 +409,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc, ThreadwiseTensorSliceSet_v1<FloatAcc,
...@@ -423,15 +423,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -423,15 +423,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size, p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size, p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
......
...@@ -15,7 +15,7 @@ template <index_t BlockSize, ...@@ -15,7 +15,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGlobalDesc, typename AGlobalDesc,
typename BGlobalDesc, typename BGlobalDesc,
typename CGlobalDesc, typename CGlobalDesc,
...@@ -84,11 +84,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -84,11 +84,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize()); p_a_global, a_e_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
constexpr auto E = EPerBlock * 3 * 3; constexpr auto E = EPerBlock * 3 * 3;
...@@ -181,7 +181,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -181,7 +181,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<E, KPerBlock>, Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K, ABlockTransferThreadClusterLengths_E_K,
...@@ -221,11 +221,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -221,11 +221,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_global_desc, b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize()); p_shared_block, a_e_k_desc.GetElementSpaceSize());
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc, FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(), c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true> true>
...@@ -250,7 +250,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -250,7 +250,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
BGlobalMoveSliceWindowStepHacks{}; BGlobalMoveSliceWindowStepHacks{};
// double register buffer for b // double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB, FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(), b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true> true>
......
...@@ -20,7 +20,7 @@ template <typename GridwiseGemm, ...@@ -20,7 +20,7 @@ template <typename GridwiseGemm,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2, typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop, bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType> ActivTypeEnum ActivType>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -50,7 +50,7 @@ __global__ void ...@@ -50,7 +50,7 @@ __global__ void
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor, cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{}, integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum_t, ActivType>{}); integral_constant<ActivTypeEnum, ActivType>{});
} }
template <typename GridwiseGemm, template <typename GridwiseGemm,
...@@ -62,7 +62,7 @@ template <typename GridwiseGemm, ...@@ -62,7 +62,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx, typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop, bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType> ActivTypeEnum ActivType>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -94,7 +94,7 @@ __global__ void ...@@ -94,7 +94,7 @@ __global__ void
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor, cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{}, integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum_t, ActivType>{}); integral_constant<ActivTypeEnum, ActivType>{});
} }
template <typename GridwiseGemm, template <typename GridwiseGemm,
...@@ -106,7 +106,7 @@ template <typename GridwiseGemm, ...@@ -106,7 +106,7 @@ template <typename GridwiseGemm,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx, typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop, bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType> ActivTypeEnum ActivType>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -140,14 +140,14 @@ __global__ void ...@@ -140,14 +140,14 @@ __global__ void
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
cblockid_to_k_n_h_w_block_cluster_adaptor, cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{}, integral_constant<bool, HasMainE0BlockLoop>{},
integral_constant<ActivTypeEnum_t, ActivType>{}); integral_constant<ActivTypeEnum, ActivType>{});
} }
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_E0_E1_K_E2, typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2, typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_Ho_Wo,
...@@ -559,7 +559,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -559,7 +559,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto bias_k0_k1_thread_desc = constexpr auto bias_k0_k1_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC, FloatC,
bias_k0_k1_thread_desc.GetElementSpaceSize(), bias_k0_k1_thread_desc.GetElementSpaceSize(),
true> true>
...@@ -602,10 +602,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -602,10 +602,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3
}); });
} }
template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum_t activ_type_> template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum activ_type_>
__device__ static void Activation(CThreadBuff& c_thread_buf, __device__ static void Activation(CThreadBuff& c_thread_buf,
const CThreadDesc_K1_N_H2_W2&, const CThreadDesc_K1_N_H2_W2&,
integral_constant<ActivTypeEnum_t, activ_type_>) integral_constant<ActivTypeEnum, activ_type_>)
{ {
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
...@@ -737,7 +737,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -737,7 +737,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
I1, I1,
Number<WoPerThread_2>{})); Number<WoPerThread_2>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC, FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(), d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true> true>
...@@ -783,7 +783,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -783,7 +783,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id, make_multi_index(k_block_work_id,
...@@ -843,7 +843,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -843,7 +843,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
I1, I1,
Number<WoPerThreadx2>{})); Number<WoPerThreadx2>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC, FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(), d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true> true>
...@@ -874,7 +874,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -874,7 +874,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum_t::Add, InMemoryDataOperationEnum::Add,
1, 1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id, make_multi_index(k_block_work_id,
...@@ -964,7 +964,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -964,7 +964,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize, BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<I1, E1, I1, KPerBlock, E2>, Sequence<I1, E1, I1, KPerBlock, E2>,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
...@@ -1023,11 +1023,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -1023,11 +1023,11 @@ struct GridwiseGemmDlops_km_kn_mn_v3
0, 0,
0)); 0));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize()); p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize());
//// register allocation for output //// register allocation for output
// StaticBuffer<AddressSpaceEnum_t::Vgpr, // StaticBuffer<AddressSpaceEnum::Vgpr,
// FloatAcc, // FloatAcc,
// c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), // c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
// true> // true>
...@@ -1050,7 +1050,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -1050,7 +1050,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{}; constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{};
// double register buffer for b // double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB, FloatAB,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(), b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(),
true> true>
...@@ -1294,21 +1294,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -1294,21 +1294,21 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto bias_k0_k1_grid_desc = const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc, FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true> true>
...@@ -1344,7 +1344,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -1344,7 +1344,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2, typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop, bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType> ActivTypeEnum ActivType>
__device__ static void ConvBiasActiv( __device__ static void ConvBiasActiv(
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
...@@ -1356,26 +1356,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -1356,26 +1356,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>, integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum_t, ActivType>) integral_constant<ActivTypeEnum, ActivType>)
{ {
static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{}; static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc = const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc, FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true> true>
...@@ -1423,7 +1423,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -1423,7 +1423,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx, typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop, bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType> ActivTypeEnum ActivType>
__device__ static void ConvBiasActivMaxpool( __device__ static void ConvBiasActivMaxpool(
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
...@@ -1437,28 +1437,28 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -1437,28 +1437,28 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>, integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum_t, ActivType>) integral_constant<ActivTypeEnum, ActivType>)
{ {
static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{}; static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc = const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc, FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true> true>
...@@ -1514,7 +1514,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -1514,7 +1514,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx, typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop, bool HasMainE0BlockLoop,
ActivTypeEnum_t ActivType> ActivTypeEnum ActivType>
__device__ static void ConvBiasActivResizeAdd( __device__ static void ConvBiasActivResizeAdd(
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
...@@ -1527,26 +1527,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -1527,26 +1527,26 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>, integral_constant<bool, HasMainE0BlockLoop>,
integral_constant<ActivTypeEnum_t, ActivType>) integral_constant<ActivTypeEnum, ActivType>)
{ {
static constexpr auto activ_type = integral_constant<ActivTypeEnum_t, ActivType>{}; static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
const auto bias_k0_k1_grid_desc = const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output // register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc, FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true> true>
......
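The three ConvBiasActiv* entry points above all take the activation kind as a non-type template parameter and receive it through an `integral_constant<ActivTypeEnum, ActivType>` tag. The following is a minimal, self-contained sketch of that dispatch pattern only; the enumerator names (`None`, `Relu`) and the simplified `integral_constant` are assumptions for illustration and are not taken from this commit.

```cpp
#include <cstdio>

enum struct ActivTypeEnum { None, Relu };  // assumed enumerators; not shown in this diff

// simplified stand-in for ck::integral_constant
template <typename T, T v>
struct integral_constant
{
    static constexpr T value = v;
};

template <ActivTypeEnum ActivType>
void conv_bias_activ(integral_constant<ActivTypeEnum, ActivType>)
{
    // the gridwise kernels recover the value the same way, as a constexpr object
    static constexpr auto activ_type = integral_constant<ActivTypeEnum, ActivType>{};
    std::printf("activation id = %d\n", static_cast<int>(activ_type.value));
}

int main()
{
    // the activation choice is resolved entirely at compile time
    conv_bias_activ(integral_constant<ActivTypeEnum, ActivTypeEnum::Relu>{});
}
```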
...@@ -79,8 +79,8 @@ template <typename FloatAB, ...@@ -79,8 +79,8 @@ template <typename FloatAB,
typename CElementwiseOperation, typename CElementwiseOperation,
typename D0ReduceOperation, typename D0ReduceOperation,
typename D1ReduceOperation, typename D1ReduceOperation,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
InMemoryDataOperationEnum_t DGlobalMemoryDataOperation, InMemoryDataOperationEnum DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
...@@ -363,15 +363,15 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -363,15 +363,15 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto d1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
...@@ -399,7 +399,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -399,7 +399,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation, AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>, Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
...@@ -430,7 +430,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -430,7 +430,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>, Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
...@@ -484,10 +484,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -484,10 +484,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
...@@ -563,7 +563,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -563,7 +563,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared), static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
...@@ -632,7 +632,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -632,7 +632,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, 7,
1, 1,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>{ true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
...@@ -723,13 +723,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -723,13 +723,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{})); make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
// TODO: this should be implemented as a blockwise reduction // TODO: this should be implemented as a blockwise reduction
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatCShuffle>( auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatCShuffle>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize()); d_reduce_thread_desc_mperblock.GetElementSpaceSize());
auto d1_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatCShuffle>( auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
d_reduce_thread_desc_mperblock.GetElementSpaceSize()); d_reduce_thread_desc_mperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR // reduce: threadwise copy from LDS to VGPR
......
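Throughout these hunks, global, LDS, and VGPR buffers are created with `make_dynamic_buffer<AddressSpaceEnum::...>(pointer, element_space_size)`. Below is a simplified stand-in, not the library's implementation, that shows the pattern being relied on: the address space travels as a template tag, so buffers over different memory spaces stay distinct types even though each wraps only a pointer and an element count. The struct layout and enumerators other than `Global`/`Lds`/`Vgpr` usage seen above are assumptions.

```cpp
#include <cstddef>
#include <cstdio>

enum struct AddressSpaceEnum { Generic, Global, Lds, Vgpr };

// assumed minimal shape of a dynamic buffer: pointer + element count, tagged by address space
template <AddressSpaceEnum AddrSpace, typename T>
struct DynamicBuffer
{
    T* p_data_;
    std::size_t element_space_size_;

    static constexpr AddressSpaceEnum GetAddressSpace() { return AddrSpace; }
};

template <AddressSpaceEnum AddrSpace, typename T>
constexpr auto make_dynamic_buffer(T* p, std::size_t element_space_size)
{
    return DynamicBuffer<AddrSpace, T>{p, element_space_size};
}

int main()
{
    float storage[64] = {};

    // mirrors a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_a_grid, size)
    auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(storage, 64);

    std::printf("address space id = %d, elements = %zu\n",
                static_cast<int>(a_grid_buf.GetAddressSpace()),
                a_grid_buf.element_space_size_);
}
```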
...@@ -60,7 +60,7 @@ template <typename FloatAB, ...@@ -60,7 +60,7 @@ template <typename FloatAB,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
...@@ -316,11 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -316,11 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
...@@ -348,7 +348,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -348,7 +348,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation, AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>, Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
...@@ -379,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -379,7 +379,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockwiseTensorSliceTransfer_v4r1<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>, Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
...@@ -433,10 +433,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -433,10 +433,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
...@@ -512,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -512,7 +512,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared), static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
...@@ -581,7 +581,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -581,7 +581,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, 7,
1, 1,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>{ true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
......
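The LDS layout in the hunks above places the B tile immediately after the A tile, with the A tile's size rounded up to `max_lds_align` so both tiles share one `p_shared` allocation. A rough sketch of that offset arithmetic follows; the element counts and alignment are hypothetical, and `integer_least_multiple` is re-derived here rather than taken from `ck::math`.

```cpp
#include <cstdio>

// assumed equivalent of math::integer_least_multiple: round x up to a multiple of align
constexpr int integer_least_multiple(int x, int align)
{
    return ((x + align - 1) / align) * align;
}

int main()
{
    constexpr int a_block_elems = 2051;  // hypothetical A-tile element count
    constexpr int b_block_elems = 4096;  // hypothetical B-tile element count
    constexpr int max_lds_align = 8;     // hypothetical alignment (K1-derived in the real kernel)

    // a_block_buf starts at p_shared; b_block_buf starts at p_shared + this offset
    constexpr int a_block_space_size_aligned =
        integer_least_multiple(a_block_elems, max_lds_align);

    std::printf("A tile: [0, %d)  B tile offset: %d  total elements: %d\n",
                a_block_elems, a_block_space_size_aligned,
                a_block_space_size_aligned + b_block_elems);
}
```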
...@@ -132,7 +132,7 @@ template <index_t BlockSize, ...@@ -132,7 +132,7 @@ template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
...@@ -426,11 +426,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -426,11 +426,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
...@@ -460,7 +460,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -460,7 +460,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BlockwiseTensorSliceTransfer_v4r1<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
AElementwiseOperation, AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, MPerBlock, K1>, Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
...@@ -491,7 +491,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -491,7 +491,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BlockwiseTensorSliceTransfer_v4r1<BlockSize, BlockwiseTensorSliceTransfer_v4r1<BlockSize,
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>, Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
...@@ -543,10 +543,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -543,10 +543,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto a_block_space_size_aligned = constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize()); b_block_desc_k0_n_k1.GetElementSpaceSize());
......
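Each gridwise GEMM above also takes an `InMemoryDataOperationEnum` (e.g. `CGlobalMemoryDataOperation`, or `InMemoryDataOperationEnum::Set` for the A/B LDS transfers) to pick how the destination is written. The sketch below shows one way such a compile-time selector can switch between overwrite and accumulate; it is illustrative only, the real transfer classes are far more elaborate, and any enumerator other than `Set` is an assumption here.

```cpp
#include <cstdio>

enum struct InMemoryDataOperationEnum { Set, Add };  // Add is assumed for illustration

template <InMemoryDataOperationEnum Op, typename T>
void store(T& dst, T src)
{
    if constexpr (Op == InMemoryDataOperationEnum::Set)
        dst = src;   // plain overwrite, as used for the LDS staging copies above
    else
        dst += src;  // accumulate into the destination
}

int main()
{
    float c = 1.0f;
    store<InMemoryDataOperationEnum::Set>(c, 3.0f);
    std::printf("after Set: %g\n", c);  // 3
    store<InMemoryDataOperationEnum::Add>(c, 2.0f);
    std::printf("after Add: %g\n", c);  // 5
}
```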