Commit a3b4c5cb authored by wangshaojie6

merge develop branch and add gridwise pipeline v3

parents 48918ab9 1677cf70
@@ -12,6 +12,7 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
#include "gemm_specialization.hpp"
+ #include "device_prop.hpp"
#ifndef CK_RUN_KERNEL_AND_TIME
#define CK_RUN_KERNEL_AND_TIME 1
@@ -332,17 +333,16 @@ struct DeviceGemmXdlSplitK
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC);
+ block_2_ctile_map_ =
+ GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
- M01_,
- N01_))
+ block_2_ctile_map_))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_);
- block_2_ctile_map_ =
- GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
}
}
@@ -385,21 +385,24 @@ struct DeviceGemmXdlSplitK
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
+ ShowInfo(arg);
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
- arg.M01_,
- arg.N01_))
+ arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
- const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
+ const index_t grid_size =
+ arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
@@ -408,50 +411,30 @@ struct DeviceGemmXdlSplitK
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
- if(nrepeat > 0)
- {
- ShowInfo(arg);
- ave_time = launch_and_time_kernel(kernel,
- nrepeat,
- dim3(grid_size),
- dim3(BlockSize),
- 0,
- arg.p_a_grid_,
- arg.p_b_grid_,
- arg.p_c_grid_,
- arg.a_grid_desc_kbatch_k0_m_k1_,
- arg.b_grid_desc_kbatch_k0_n_k1_,
- arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
- arg.a_element_op_,
- arg.b_element_op_,
- arg.c_element_op_,
- arg.block_2_ctile_map_);
- }
- if(kbatch > 1 || nrepeat <= 0)
- {
- hipGetErrorString(
- hipMemset(arg.p_c_grid_,
- 0,
- arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
- sizeof(CDataType)));
- launch_kernel(kernel,
- dim3(grid_size),
- dim3(BlockSize),
- 0,
- arg.p_a_grid_,
- arg.p_b_grid_,
- arg.p_c_grid_,
- arg.a_grid_desc_kbatch_k0_m_k1_,
- arg.b_grid_desc_kbatch_k0_n_k1_,
- arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
- arg.a_element_op_,
- arg.b_element_op_,
- arg.c_element_op_,
- arg.block_2_ctile_map_);
- }
+ // FIXME: this should be moved outside of DeviceOp
+ hipGetErrorString(
+ hipMemset(arg.p_c_grid_,
+ 0,
+ arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() *
+ sizeof(CDataType)));
+ ave_time = launch_and_time_kernel(stream_config,
+ kernel,
+ dim3(grid_size),
+ dim3(BlockSize),
+ 0,
+ arg.p_a_grid_,
+ arg.p_b_grid_,
+ arg.p_c_grid_,
+ arg.a_grid_desc_kbatch_k0_m_k1_,
+ arg.b_grid_desc_kbatch_k0_n_k1_,
+ arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+ arg.a_element_op_,
+ arg.b_element_op_,
+ arg.c_element_op_,
+ arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
@@ -531,9 +514,10 @@ struct DeviceGemmXdlSplitK
}
// polymorphic
- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+ const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
@@ -545,11 +529,15 @@ struct DeviceGemmXdlSplitK
static bool IsSupportedArgument(const Argument& arg)
{
+ if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+ {
+ return false;
+ }
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
- arg.M01_,
- arg.N01_);
+ arg.block_2_ctile_map_);
}
// polymorphic
......
@@ -292,8 +292,7 @@ struct DeviceGemmXdlSplitKCShuffle
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
- using Block2CTileMap =
- decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
+ using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor;
// Argument
struct Argument : public BaseArgument
@@ -338,17 +337,16 @@ struct DeviceGemmXdlSplitKCShuffle
K, N, StrideB, k_batch_, KPad);
c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
+ block_2_ctile_map_ =
+ GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
- M01_,
- N01_))
+ block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
- block_2_ctile_map_ =
- GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
}
}
@@ -391,21 +389,24 @@ struct DeviceGemmXdlSplitKCShuffle
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
+ ShowInfo(arg);
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
- arg.M01_,
- arg.N01_))
+ arg.block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting");
}
- const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
+ const index_t grid_size =
+ arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
@@ -414,51 +415,29 @@ struct DeviceGemmXdlSplitKCShuffle
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
- if(nrepeat > 0)
- {
- ShowInfo(arg);
- ave_time =
- launch_and_time_kernel(kernel,
- nrepeat,
- dim3(grid_size),
- dim3(BlockSize),
- 0,
- arg.p_a_grid_,
- arg.p_b_grid_,
- arg.p_c_grid_,
- arg.a_grid_desc_kbatch_k0_m_k1_,
- arg.b_grid_desc_kbatch_k0_n_k1_,
- arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
- arg.a_element_op_,
- arg.b_element_op_,
- arg.c_element_op_,
- arg.block_2_ctile_map_);
- }
- if(kbatch > 1 || nrepeat <= 0)
- {
- hipGetErrorString(hipMemset(
- arg.p_c_grid_,
- 0,
- arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
- sizeof(CDataType)));
- launch_kernel(kernel,
- dim3(grid_size),
- dim3(BlockSize),
- 0,
- arg.p_a_grid_,
- arg.p_b_grid_,
- arg.p_c_grid_,
- arg.a_grid_desc_kbatch_k0_m_k1_,
- arg.b_grid_desc_kbatch_k0_n_k1_,
- arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
- arg.a_element_op_,
- arg.b_element_op_,
- arg.c_element_op_,
- arg.block_2_ctile_map_);
- }
+ hipGetErrorString(hipMemset(
+ arg.p_c_grid_,
+ 0,
+ arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
+ sizeof(CDataType)));
+ launch_and_time_kernel(stream_config,
+ kernel,
+ dim3(grid_size),
+ dim3(BlockSize),
+ 0,
+ arg.p_a_grid_,
+ arg.p_b_grid_,
+ arg.p_c_grid_,
+ arg.a_grid_desc_kbatch_k0_m_k1_,
+ arg.b_grid_desc_kbatch_k0_n_k1_,
+ arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+ arg.a_element_op_,
+ arg.b_element_op_,
+ arg.c_element_op_,
+ arg.block_2_ctile_map_);
};
if(has_main_k0_block_loop)
{
if(kbatch == 1)
@@ -542,9 +521,10 @@ struct DeviceGemmXdlSplitKCShuffle
}
// polymorphic
- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+ const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
@@ -559,8 +539,7 @@ struct DeviceGemmXdlSplitKCShuffle
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_,
- arg.M01_,
- arg.N01_);
+ arg.block_2_ctile_map_);
}
// polymorphic
......
@@ -17,6 +17,62 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
index_t group_id = 0;
for(index_t i = 0; i < group_count; i++)
{
group_id =
(block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
? i
: group_id;
}
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_ptr[group_id].a_ptr,
gemm_desc_ptr[group_id].b_ptr,
gemm_desc_ptr[group_id].c_ptr,
p_shared,
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_);
#else
ignore = gemm_descs_const;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
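The kernel above picks which GEMM of the group a workgroup handles by linearly scanning the per-group [BlockStart_, BlockEnd_) block ranges. A self-contained sketch of that selection rule (a hypothetical BlockRange struct stands in for the corresponding GemmDescKernelArg fields):

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the BlockStart_/BlockEnd_ fields of GemmDescKernelArg.
struct BlockRange
{
    int32_t BlockStart_;
    int32_t BlockEnd_; // one past the last block of this group
};

// Same selection rule as the loop in kernel_grouped_gemm_xdlops_v2r3: the range
// containing block_id wins (ranges are disjoint, so at most one matches).
int32_t find_group_id(const std::vector<BlockRange>& ranges, int32_t block_id)
{
    int32_t group_id = 0;
    for(std::size_t i = 0; i < ranges.size(); i++)
    {
        group_id = (block_id >= ranges[i].BlockStart_ && block_id < ranges[i].BlockEnd_)
                       ? static_cast<int32_t>(i)
                       : group_id;
    }
    return group_id;
}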
template <typename ADataType,
typename BDataType,
typename CDataType,
@@ -225,6 +281,11 @@ struct DeviceGroupedGemmXdl
struct GroupedGemmBlock2CTileMap
{
+ using UnderlyingBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
+ static_assert(
+ std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)),
+ typename GridwiseGemm::DefaultBlock2CTileMap>::value,
+ "Wrong! Should be the same type name");
GroupedGemmBlock2CTileMap()
{
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1);
@@ -247,7 +308,18 @@ struct DeviceGroupedGemmXdl
make_multi_index(idx_top[I0] - BlockStart_));
}
- private:
+ template <typename CTileIdx, typename CTileDim>
+ __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+ const CTileDim& c_tile_dim) const
+ {
+ return block_2_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+ }
+ __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+ {
+ return block_2_ctile_map_.CheckValidity(c_grid_desc_m_n);
+ }
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
ck::index_t BlockStart_;
};
@@ -290,17 +362,20 @@ struct DeviceGroupedGemmXdl
{
grid_size_ = 0;
- group_count_ = static_cast<int>(gemm_shapes.size());
+ gemm_descs_args_workspace_ = nullptr;
+ group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
- if(!(group_count_ == p_a.size() && group_count_ == p_b.size() &&
- group_count_ == p_c.size()))
+ if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
+ group_count_ == ck::type_convert<ck::index_t>(p_b.size()) &&
+ group_count_ == ck::type_convert<ck::index_t>(p_c.size())))
{
throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
}
gemm_desc_kernel_arg_.reserve(group_count_);
- for(index_t i = 0; i < gemm_shapes.size(); i++)
+ for(std::size_t i = 0; i < gemm_shapes.size(); i++)
{
const index_t M = gemm_shapes[i].M;
const index_t N = gemm_shapes[i].N;
@@ -317,22 +392,26 @@ struct DeviceGroupedGemmXdl
const auto c_grid_desc_m_n_ =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
- const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);
+ const index_t grid_size_grp =
+ GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0)
+ .block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
- if(GridwiseGemm::CheckValidity(
- a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
+ const auto grouped_gemm_block_2_ctile_map_ =
+ GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);
+ if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
+ b_grid_desc_k0_n_k1_,
+ c_grid_desc_m_n_,
+ grouped_gemm_block_2_ctile_map_))
{
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
- const auto grouped_gemm_block_2_ctile_map_ =
- GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);
gemm_desc_kernel_arg_.push_back(
GemmDescKernelArg{a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
@@ -358,6 +437,8 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
+ void* gemm_descs_args_workspace_;
index_t grid_size_;
};
@@ -366,83 +447,77 @@ struct DeviceGroupedGemmXdl
{
using Argument = DeviceGroupedGemmXdl::Argument;
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
- StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_arg_arg;
- bool has_main_k0_block_loop = true;
- static_for<0, MaxGroupCount, 1>{}([&](auto i) {
- if(i < arg.gemm_desc_kernel_arg_.size())
- {
- gemm_desc_kernel_arg_arg(i) = arg.gemm_desc_kernel_arg_[i];
- std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
- << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0)
- << ", "
- << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1)
- << ", "
- << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2)
- << "}";
- std::cout << ", arg.b_grid_desc_k0_n_k1_{"
- << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0)
- << ", "
- << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1)
- << ", "
- << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2)
- << "}";
- std::cout << ", arg.c_grid_desc_m_n_{ "
- << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", "
- << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}"
- << std::endl;
- if(!GridwiseGemm::CheckValidity(
- gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_,
- gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_,
- gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_,
- arg.M01_,
- arg.N01_))
- {
- throw std::runtime_error(
- "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
- }
- const auto K0 = gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0);
- if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
- {
- throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop");
- }
- }
- });
+ bool has_main_k_block_loop = true;
+ for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
+ {
+ std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
+ << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
+ << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+ << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
+ std::cout << ", arg.b_grid_desc_k0_n_k1_{"
+ << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
+ << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+ << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
+ std::cout << ", arg.c_grid_desc_m_n_{ "
+ << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
+ << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
+ << std::endl;
+ if(!GridwiseGemm::CheckValidity(
+ arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
+ arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
+ arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_,
+ arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_))
+ {
+ throw std::runtime_error(
+ "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
+ }
+ const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
+ arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2);
+ if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
+ {
+ throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
+ }
+ }
+ hipGetErrorString(
+ hipMemcpy(arg.gemm_descs_args_workspace_,
+ arg.gemm_desc_kernel_arg_.data(),
+ arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
+ hipMemcpyHostToDevice));
float ave_time = 0;
- if(has_main_k0_block_loop)
+ if(has_main_k_block_loop)
{
const auto kernel =
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
- remove_reference_t<GemmDescKernelArg>,
+ GemmDescKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
- true,
- MaxGroupCount>;
+ true>;
- ave_time = launch_and_time_kernel(kernel,
- nrepeat,
+ ave_time = launch_and_time_kernel(
+ stream_config,
+ kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
- gemm_desc_kernel_arg_arg,
+ cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
else
{
@@ -450,32 +525,33 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
- remove_reference_t<GemmDescKernelArg>,
+ GemmDescKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
- false,
- MaxGroupCount>;
+ false>;
- ave_time = launch_and_time_kernel(kernel,
- nrepeat,
+ ave_time = launch_and_time_kernel(
+ stream_config,
+ kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
- gemm_desc_kernel_arg_arg,
+ cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
return ave_time;
}
// polymorphic
- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+ const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
@@ -487,7 +563,7 @@ struct DeviceGroupedGemmXdl
static bool IsSupportedArgument(const Argument& arg)
{
- if(arg.gemm_desc_kernel_arg_.size() != arg.group_count_)
+ if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
return false;
else
return true;
@@ -554,6 +630,16 @@ struct DeviceGroupedGemmXdl
return str.str();
}
+ size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
+ {
+ return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
+ }
+ void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
+ {
+ dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
+ }
};
} // namespace device
......
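With the grouped-GEMM kernel arguments now staged in device memory, the new GetWorkSpaceSize and SetWorkSpacePointer overrides above imply a host-side sequence along the following lines. This is a sketch only; device_op, argument_ptr and invoker_ptr stand for the usual CK factory results, which are not part of this diff, and error handling is omitted:

// Query how many bytes the device op needs for its per-group descriptors.
std::size_t workspace_size = device_op.GetWorkSpaceSize(argument_ptr.get());

// Allocate the workspace on the device and hand it to the argument.
void* workspace = nullptr;
hipMalloc(&workspace, workspace_size); // group_count_ * sizeof(GemmDescKernelArg)
device_op.SetWorkSpacePointer(argument_ptr.get(), workspace);

// Run() then copies gemm_desc_kernel_arg_ into the workspace and launches the
// grouped kernel with the workspace pointer cast to the constant address space.
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{});

hipFree(workspace);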
@@ -17,7 +17,7 @@ template <typename InDataType,
typename OutDataType,
typename AccDataType,
ck::ReduceTensorOp ReduceOpId,
- bool NeedIndices,
+ bool OuputIndex,
ck::index_t BlockSize,
ck::index_t ReduceMThreadClusterSize,
ck::index_t ReduceKThreadClusterSize,
@@ -44,8 +44,6 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation;
- static constexpr bool BetaIsZero = true;
static constexpr index_t InSrcOutDstVectorDim =
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
// not reduced.
@@ -204,30 +202,30 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
struct Invoker : public BaseInvoker
{
- float Run(const Argument& arg, int nrepeat = 1)
+ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
- using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
- OutDataType,
- AccDataType,
- IndexDataType,
- AGridDesc_M_K,
- BGridDesc_M,
- ReduceOperation,
- InElementwiseOperation,
- AccElementwiseOperation,
- false, // propagate_nan
- BetaIsZero,
- BlockSize,
- ReduceMThreadClusterSize,
- ReduceKThreadClusterSize,
- ReduceMThreadSliceSize,
- ReduceKThreadSliceSize,
- InSrcOutDstVectorDim,
- InSrcOutDstVectorSize,
- InSrcOutDstVectorSize>;
+ using gridwise_reduce =
+ GridwiseReduction_mk_to_m_threadwise<InDataType,
+ OutDataType,
+ AccDataType,
+ IndexDataType,
+ AGridDesc_M_K,
+ BGridDesc_M,
+ ReduceOperation,
+ InElementwiseOperation,
+ AccElementwiseOperation,
+ InMemoryDataOperationEnum::Set,
+ false, // propagate_nan
+ BlockSize,
+ ReduceMThreadSliceSize,
+ ReduceKThreadSliceSize,
+ InSrcOutDstVectorDim,
+ InSrcOutDstVectorSize,
+ InSrcOutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
- NeedIndices,
+ OuputIndex,
+ false, // don't have index input
InDataType,
OutDataType,
AccDataType,
@@ -241,8 +239,8 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
- return launch_and_time_kernel(kernel,
- nrepeat,
+ return launch_and_time_kernel(stream_config,
+ kernel,
dim3(grid_size),
dim3(BlockSize),
0,
@@ -252,14 +250,16 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
arg.acc_element_op_,
float(1),
arg.p_in_dev_,
+ nullptr,
float(0),
arg.p_out_dev_,
arg.p_out_indices_dev_);
}
- float Run(const BaseArgument* p_arg, int nrepeat = 1) override
+ float Run(const BaseArgument* p_arg,
+ const StreamConfig& stream_config = StreamConfig{}) override
{
- return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
+ return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
......
@@ -16,35 +16,18 @@ namespace device {
template <typename InElementwiseOperation, typename AccElementwiseOperation>
struct DeviceReduce : public BaseOperator
{
- virtual long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
- const std::vector<int> reduceDims)
- {
- (void)inLengths;
- (void)reduceDims;
- return (0);
- };
- virtual bool HasFurtherCall() { return (false); };
- virtual std::vector<int> GetWorkspace2dLengths(const BaseArgument* argPtr)
- {
- (void)argPtr;
- return (std::vector<int>{0, 0});
- };
virtual std::unique_ptr<BaseArgument>
- MakeArgumentPointer(const std::vector<int> inLengths,
- const std::vector<int> inStrides,
- const std::vector<int> outLengths,
- const std::vector<int> outStrides,
+ MakeArgumentPointer(const std::vector<index_t> inLengths,
+ const std::vector<index_t> inStrides,
+ const std::vector<index_t> outLengths,
+ const std::vector<index_t> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
+ const void* in_index_dev,
void* out_dev,
- void* out_indices_dev,
- void* workspace_dev,
+ void* out_index_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) = 0;
......
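The trimmed DeviceReduce interface drops the workspace-related virtuals and switches the length/stride vectors to index_t, adding an optional index input alongside the index output. A hedged sketch of a call against the new MakeArgumentPointer signature; reduce_op_ptr, the element-wise operators and the device buffers are placeholders, not names from this commit:

std::vector<ck::index_t> in_lengths  = {64, 32, 8, 8};
std::vector<ck::index_t> in_strides  = {2048, 64, 8, 1};
std::vector<ck::index_t> out_lengths = {64, 32};
std::vector<ck::index_t> out_strides = {32, 1};
std::vector<int> reduce_dims         = {2, 3};

auto argument_ptr = reduce_op_ptr->MakeArgumentPointer(in_lengths,
                                                       in_strides,
                                                       out_lengths,
                                                       out_strides,
                                                       reduce_dims,
                                                       1.0f,                 // alpha
                                                       0.0f,                 // beta
                                                       in_device_buf,        // const void*
                                                       nullptr,              // in_index_dev: no index input
                                                       out_device_buf,       // void*
                                                       out_index_device_buf, // void* (or nullptr)
                                                       in_elementwise_op,
                                                       acc_elementwise_op);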
#ifndef DEVICE_REDUCE_BLOCKWISE_HPP
#define DEVICE_REDUCE_BLOCKWISE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto inPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, inPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
{
const auto in_grid_desc_m_k =
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m =
DeviceReduceBlockWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_reduce_blockwise<GridwiseReduce,
NeedIndices,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_,
nullptr,
arg.out_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// cases with very small reduce_total_length should be handled by the ThreadWise method
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceBlockWise<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
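The (M, K) problem that MakeSrc2dDescriptor builds corresponds to the invariant/reduce split performed by shuffle_tensor_dimensions and get_2d_lengths. A self-contained sketch of that split using plain std::vector arithmetic instead of the CK descriptors:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Return (invariant_total_length, reduce_total_length), i.e. the M and K
// extents of the flattened 2D reduction problem. Mirrors the effect of
// shuffle_tensor_dimensions followed by get_2d_lengths.
std::pair<int64_t, int64_t> to_2d_lengths(const std::vector<int>& lengths,
                                          const std::vector<int>& reduce_dims)
{
    std::vector<bool> is_reduce(lengths.size(), false);
    for(int d : reduce_dims)
        is_reduce[d] = true;

    int64_t invariant_total_length = 1;
    int64_t reduce_total_length    = 1;

    for(std::size_t d = 0; d < lengths.size(); d++)
        (is_reduce[d] ? reduce_total_length : invariant_total_length) *= lengths[d];

    return {invariant_total_length, reduce_total_length};
}

// Example: a 64 x 32 x 8 x 8 tensor reduced over dims {2, 3} becomes a
// (M = 64*32, K = 8*8) = (2048, 64) problem.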
#ifndef DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
#define DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceBlockWiseSecondCall
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
static_assert(
std::is_same<InDataType, AccDataType>::value,
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<2>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<2>{});
const auto in_grid_desc_m_k =
make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: inLengths_(inLengths),
inStrides_(inStrides),
outLengths_(outLengths),
outStrides_(outStrides),
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op)
{
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
invariant_total_length = inLengths[0];
reduce_total_length = inLengths[1];
invariant_lowest_length = inLengths[0];
reduce_lowest_length = inLengths[1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
invariant_total_length * reduce_total_length * sizeof(AccDataType), 64);
if constexpr(NeedIndices)
workspace_indices_dev_ = reinterpret_cast<index_t*>(
reinterpret_cast<char*>(workspace_dev) + ws_buf2_bytes_offset);
else
workspace_indices_dev_ = nullptr;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
IndexDataType* workspace_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
{
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m = DeviceReduceBlockWiseSecondCall::MakeDst1dDescriptor(
arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_reduce_blockwise_second_call<GridwiseReduce,
NeedIndices,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_,
arg.workspace_indices_dev_,
arg.out_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
// To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// cases with very small reduce_total_length should be handled by the ThreadWise method
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
(void)reduceDims;
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceBlockWiseSecondCall<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
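The ws_buf2_bytes_offset computed in the Argument above places the index sub-buffer at a 64-byte-aligned offset past the accumulator workspace. A minimal sketch of that arithmetic:

#include <cstddef>

// Round a byte count up to the next multiple of `alignment`, as
// math::integer_least_multiple does for the workspace offset above.
std::size_t round_up(std::size_t bytes, std::size_t alignment)
{
    return ((bytes + alignment - 1) / alignment) * alignment;
}

// Example with AccDataType = float: a 2048 x 7 partial-result buffer occupies
// 2048 * 7 * 4 = 57344 bytes, already a multiple of 64, so the index
// sub-buffer (used when NeedIndices is true) starts at byte offset 57344.
std::size_t index_buffer_offset(std::size_t invariant_total_length,
                                std::size_t reduce_total_length,
                                std::size_t acc_elem_bytes)
{
    return round_up(invariant_total_length * reduce_total_length * acc_elem_bytes, 64);
}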
@@ -14,13 +14,13 @@ namespace device {
// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
// of reduce dims
- template <int Rank, int NumReduceDim>
- std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
+ template <index_t Rank, int NumReduceDim>
+ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
- size_t invariant_total_length = 1;
- size_t reduce_total_length = 1;
+ long_index_t invariant_total_length = 1;
+ long_index_t reduce_total_length = 1;
constexpr int NumInvariantDim = Rank - NumReduceDim;
@@ -35,13 +35,13 @@ std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
// helper functions using variadic template arguments
template <index_t... Ns>
- auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
+ auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
{
return make_tuple(static_cast<index_t>(lengths[Ns])...);
};
template <index_t arraySize>
- static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arraySize>)
+ auto make_tuple_from_array(const std::vector<index_t>& lengths, Number<arraySize>)
{
static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
@@ -51,10 +51,10 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
};
template <index_t Rank, index_t NumReduceDim>
- std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
+ std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origLengthsStrides,
const std::vector<int>& reduceDims)
{
- std::vector<int> newLengthsStrides;
+ std::vector<index_t> newLengthsStrides;
assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
......
#ifndef DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP #ifndef DEVICE_REDUCE_MULTIBLOCK_HPP
#define DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP #define DEVICE_REDUCE_MULTIBLOCK_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -7,8 +7,9 @@ ...@@ -7,8 +7,9 @@
#include "device_base.hpp" #include "device_base.hpp"
#include "device_reduce.hpp" #include "device_reduce.hpp"
#include "device_reduce_common.hpp" #include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_atomic_add.hpp" #include "gridwise_2d_reduction_multiblock.hpp"
#include "gridwise_set_buffer_value.hpp" #include "gridwise_set_buffer_value.hpp"
#include "reduction_operator.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -22,8 +23,10 @@ template <typename InDataType, ...@@ -22,8 +23,10 @@ template <typename InDataType,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan, bool PropagateNan,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -32,8 +35,7 @@ template <typename InDataType, ...@@ -32,8 +35,7 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceMultiBlockAtomicAdd struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
...@@ -46,26 +48,40 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -46,26 +48,40 @@ struct DeviceReduceMultiBlockAtomicAdd
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank; static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr bool support_AtomicAdd = // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
// later
static constexpr bool use_multiblock =
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
static constexpr bool out_type_compatible_with_atomic_op =
std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value; std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;
static_assert(!NeedIndices && support_AtomicAdd, static_assert(
"MultiBlockAtomicAdd method can only be used with non-indiced operation and when " !use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
"having float/double output type!"); "The OutDataType must support the atomic operation for using MultiBlock reduction");
static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
"MultiBlock reduction can only be used when outputing index is not required");
static_assert(
ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
"The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<int>& inStrides, const std::vector<index_t>& inStrides,
int blkGroupSize, int blkGroupSize,
int kBlockTileIterations) int numBlockTileIteration)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -109,7 +125,7 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -109,7 +125,7 @@ struct DeviceReduceMultiBlockAtomicAdd
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations; const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M = const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
...@@ -124,8 +140,8 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -124,8 +140,8 @@ struct DeviceReduceMultiBlockAtomicAdd
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths, static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<int>& outStrides) const std::vector<index_t>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
...@@ -151,31 +167,56 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -151,31 +167,56 @@ struct DeviceReduceMultiBlockAtomicAdd
return (out_grid_desc_m_padded); return (out_grid_desc_m_padded);
}; };
static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths,
const std::vector<index_t>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto length = out_grid_desc_m.GetLength(Number<0>{});
const auto pad = math::integer_least_multiple(length, BlockSize) - length;
auto out_grid_desc_m_padded =
transform_tensor_descriptor(out_grid_desc_m,
make_tuple(make_right_pad_transform(length, pad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
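For context, the new MakeDst1dDescriptorForBufferSet pads the merged 1-D output length up to a multiple of BlockSize so kernel_buffer_set_value can cover the whole output with complete workgroups; this matches the gridSize_pre computation in the Argument constructor below. A minimal standalone sketch of that padding arithmetic, with made-up lengths:

// Standalone sketch (not part of the diff): the padding arithmetic behind
// MakeDst1dDescriptorForBufferSet and gridSize_pre; all sizes are illustrative.
#include <cstdio>

int main()
{
    const int block_size    = 256;  // BlockSize (assumed example value)
    const int output_length = 1000; // merged 1-D output length (assumed example value)

    // math::integer_least_multiple(length, BlockSize)
    const int padded_length = (output_length + block_size - 1) / block_size * block_size;
    const int pad           = padded_length - output_length;
    const int grid_size_pre = padded_length / block_size; // one workgroup per BlockSize chunk

    std::printf("padded=%d pad=%d gridSize_pre=%d\n", padded_length, pad, grid_size_pre);
    // prints: padded=1024 pad=24 gridSize_pre=4
    return 0;
}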
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<int> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
const IndexDataType* in_index_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_indices_dev, IndexDataType* out_index_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths}, : outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
in_index_dev_{in_index_dev},
out_dev_{out_dev}, out_dev_{out_dev},
out_index_dev_{out_index_dev},
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
(void)out_indices_dev;
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
...@@ -192,23 +233,34 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -192,23 +233,34 @@ struct DeviceReduceMultiBlockAtomicAdd
reduce_lowest_length = inLengths_[Rank - 1]; reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1; if constexpr(use_multiblock)
while(true)
{ {
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
// we want the blkGroupSize be not more than 128 int iterations = 1;
if(testBlkGroupSize <= 128) while(true)
break; {
int testBlkGroupSize =
(reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
iterations++; // we want blkGroupSize to be no more than 128
}; if(testBlkGroupSize <= 128)
break;
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / iterations++;
(K_BlockTileSize * iterations); };
kBlockTileIterations = iterations; blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
numBlockTileIteration = iterations;
}
else
{
blkGroupSize = 1;
numBlockTileIteration =
(reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
};
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize; M_BlockTileSize * blkGroupSize;
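The loop above grows the per-block iteration count until blkGroupSize (the number of blocks cooperating along the reduce dimension) drops to 128 or fewer, while the Set path keeps a single block group. A standalone sketch of that selection with made-up problem sizes, not values taken from the source:

// Standalone sketch (not part of the diff): how blkGroupSize, numBlockTileIteration
// and gridSize come out of the logic above for illustrative sizes.
#include <cstdio>

int main()
{
    const long reduce_total_length    = 1 << 20; // assumed
    const long invariant_total_length = 4096;    // assumed
    const int  K_BlockTileSize        = 128;     // KThreadClusterSize * KThreadSliceSize (assumed)
    const int  M_BlockTileSize        = 128;     // MThreadClusterSize * MThreadSliceSize (assumed)

    int iterations = 1;
    while((reduce_total_length + long(K_BlockTileSize) * iterations - 1) /
              (long(K_BlockTileSize) * iterations) > 128) // keep blkGroupSize <= 128
        iterations++;

    const int  blkGroupSize = (reduce_total_length + long(K_BlockTileSize) * iterations - 1) /
                              (long(K_BlockTileSize) * iterations);
    const long m_blocks     = (invariant_total_length + M_BlockTileSize - 1) / M_BlockTileSize;
    const long gridSize     = m_blocks * blkGroupSize;

    std::printf("iterations=%d blkGroupSize=%d gridSize=%ld\n", iterations, blkGroupSize, gridSize);
    // prints: iterations=64 blkGroupSize=128 gridSize=4096
    return 0;
}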
...@@ -217,27 +269,29 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -217,27 +269,29 @@ struct DeviceReduceMultiBlockAtomicAdd
math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
} }
std::vector<int> inLengths_; std::vector<index_t> inLengths_;
std::vector<int> inStrides_; std::vector<index_t> inStrides_;
std::vector<int> outLengths_; std::vector<index_t> outLengths_;
std::vector<int> outStrides_; std::vector<index_t> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
const IndexDataType* in_index_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
IndexDataType* out_index_dev_;
InElementwiseOperation in_elementwise_op_; InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_; AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length; index_t invariant_lowest_length;
int reduce_lowest_length; index_t reduce_lowest_length;
size_t invariant_total_length; long_index_t invariant_total_length;
size_t reduce_total_length; long_index_t reduce_total_length;
index_t blkGroupSize; int blkGroupSize;
index_t kBlockTileIterations; int numBlockTileIteration;
size_t gridSize; size_t gridSize;
size_t gridSize_pre; size_t gridSize_pre;
...@@ -245,91 +299,97 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -245,91 +299,97 @@ struct DeviceReduceMultiBlockAtomicAdd
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor( const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations); arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
const auto out_grid_desc_m = DeviceReduceMultiBlockAtomicAdd::MakeDst1dDescriptor( const auto out_grid_desc_m =
DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet(
arg.outLengths_, arg.outStrides_); arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce =
GridwiseReduction_mk_to_m_multiblock_atomic_add<InDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
KernelTimer timer;
const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>; using InGridDesc_M_K = decltype(in_grid_desc_m_k);
const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce, using OutGridDesc_M = decltype(out_grid_desc_m);
InDataType, using OutGridDesc_M_2 = decltype(out_grid_desc_m_2);
OutDataType,
AccDataType, using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock<InDataType,
InGridDesc_M_K, OutDataType,
OutGridDesc_M, AccDataType,
InElementwiseOperation, IndexDataType,
AccElementwiseOperation>; InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
OutMemoryDataOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
const auto kernel_main = kernel_reduce_multiblock<GridwiseReduce,
OutputIndex,
HaveIndexInput,
InDataType,
OutDataType,
AccDataType,
int32_t,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
printf("launch_and_time_kernel: grid_dim {%ld, 1, 1}, block_dim {%d, 1, 1} \n", float avg_time = 0;
arg.gridSize,
BlockSize);
printf("Warm up\n");
for(int i = 0; i < nrepeat + 1; i++) if constexpr(use_multiblock)
{ {
if(i == 1) const auto identityVal =
timer.Start(); ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
OutMemoryDataOperation);
launch_kernel(kernel_pre,
dim3(arg.gridSize_pre), const auto kernel_pre =
dim3(BlockSize), kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M_2>;
0,
out_grid_desc_m, avg_time += launch_and_time_kernel(stream_config,
arg.out_dev_, kernel_pre,
static_cast<OutDataType>(0.0f)); dim3(arg.gridSize_pre),
dim3(BlockSize),
launch_kernel(kernel_main, 0,
dim3(arg.gridSize), out_grid_desc_m_2,
dim3(BlockSize), arg.out_dev_,
0, identityVal);
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.blkGroupSize,
arg.kBlockTileIterations,
arg.alpha_,
arg.in_dev_,
arg.out_dev_);
}; };
timer.End(); avg_time += launch_and_time_kernel(stream_config,
kernel_main,
avg_time = timer.GetElapsedTime() / nrepeat; dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.blkGroupSize,
arg.numBlockTileIteration,
arg.alpha_,
arg.in_dev_,
arg.in_index_dev_,
arg.beta_,
arg.out_dev_,
arg.out_index_dev_);
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}; };
}; };
...@@ -337,6 +397,12 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -337,6 +397,12 @@ struct DeviceReduceMultiBlockAtomicAdd
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(use_multiblock)
{
if(static_cast<float>(pArg->beta_) != 0.0f)
return (false);
};
if constexpr(InSrcVectorDim == 0) if constexpr(InSrcVectorDim == 0)
{ {
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
...@@ -361,36 +427,43 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -361,36 +427,43 @@ struct DeviceReduceMultiBlockAtomicAdd
return (false); return (false);
}; };
if(static_cast<float>(pArg->beta_) != 0.0f)
return (false);
// To improve // To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0) if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false); return (false);
// cases with small reduce_total_length should be handled by the BlockWise method if constexpr(use_multiblock)
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize) {
return (false); // blkGroupSize of 1 should be handled by Blockwise path using
// InMemoryDataOperationEnum::Set
if(pArg->blkGroupSize == 1)
return (false);
// This is a very strong restriction, but needed to avoid some failures // This is a very strong restriction, but needed to avoid some failures
if(pArg->invariant_lowest_length % M_BlockTileSize != 0) if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
return (false); return (false);
}
else
{
// cases with very small reduce_total_length should be handled by ThreadWise kernel
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
};
return (true); return (true);
}; };
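In short, the AtomicAdd (multiblock) path accumulates partial results directly into the output buffer, so it cannot blend a prior C with beta, and a blkGroupSize of 1 is deliberately left to the Set/Blockwise variant. A standalone sketch of that host-side gist, with plainly hypothetical inputs:

// Standalone sketch (not part of the diff): the gist of the multiblock checks above.
#include <cstdio>

bool multiblock_supported(float beta, int blkGroupSize, int invariant_lowest_length, int m_block_tile)
{
    if(beta != 0.0f) return false;                                // atomic accumulation cannot apply beta * C
    if(blkGroupSize == 1) return false;                           // single block group belongs to the Set path
    if(invariant_lowest_length % m_block_tile != 0) return false; // strong restriction kept from above
    return true;
}

int main()
{
    std::printf("%d\n", multiblock_supported(0.0f, 16, 256, 128)); // 1
    std::printf("%d\n", multiblock_supported(1.0f, 16, 256, 128)); // 0
    return 0;
}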
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_index_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override const AccElementwiseOperation acc_elementwise_op) override
{ {
...@@ -402,9 +475,9 @@ struct DeviceReduceMultiBlockAtomicAdd ...@@ -402,9 +475,9 @@ struct DeviceReduceMultiBlockAtomicAdd
alpha, alpha,
beta, beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
static_cast<const IndexDataType*>(in_index_dev),
static_cast<OutDataType*>(out_dev), static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev), static_cast<IndexDataType*>(out_index_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op, in_elementwise_op,
acc_elementwise_op); acc_elementwise_op);
}; };
......
#ifndef DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceMultiBlockPartialReduce
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
using IndexDataType = int32_t;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr int MaxBlockGroupSize = 256;
long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims) override
{
size_t invariant_total_length;
size_t reduce_total_length;
auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
};
int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
long_index_t workspace_size = invariant_total_length * blkGroupSize;
long_index_t wsSizeInBytes =
!NeedIndices
? workspace_size * sizeof(AccDataType)
: workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);
return (wsSizeInBytes);
};
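GetWorkspaceSizeInBytes reserves one partial accumulator per (invariant element, block group), plus an aligned index buffer when indices are requested. A standalone sketch of that sizing with made-up dimensions:

// Standalone sketch (not part of the diff): workspace sizing as in GetWorkspaceSizeInBytes,
// with illustrative sizes and AccDataType assumed to be float.
#include <cstdio>

int main()
{
    const long invariant_total_length = 1024; // assumed
    const int  blkGroupSize           = 8;    // assumed result of the iteration loop above
    const long workspace_elems        = invariant_total_length * blkGroupSize;

    const bool need_indices = true; // NeedIndices (assumed)

    const long ws_bytes = !need_indices
                              ? workspace_elems * long(sizeof(float))
                              : workspace_elems * long(sizeof(float) + sizeof(int)) + 64 + long(sizeof(int));

    std::printf("workspace bytes = %ld\n", ws_bytes); // 65604 for these values
    return 0;
}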
bool HasFurtherCall() override { return (true); };
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
int blkGroupSize,
int kBlockTileIterations)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
{
auto ws_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto wsPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto ws_desc_m_k_padded =
transform_tensor_descriptor(ws_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, wsPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (ws_desc_m_k_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
workspace_dev_{workspace_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
};
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
kBlockTileIterations = iterations;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
invariant_total_length * blkGroupSize * sizeof(AccDataType), 64);
if constexpr(NeedIndices)
workspace_indices_dev_ = reinterpret_cast<int*>(
reinterpret_cast<char*>(workspace_dev_) + ws_buf2_bytes_offset);
else
workspace_indices_dev_ = nullptr;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
AccDataType* workspace_dev_;
IndexDataType* workspace_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
index_t blkGroupSize;
index_t kBlockTileIterations;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
const auto ws_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeWorkspace2dDescriptor(
arg.invariant_total_length, arg.blkGroupSize);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using WorkspaceDesc_M_K = decltype(ws_desc_m_k);
using GridwiseReduce =
GridwiseReduction_mk_to_mk_multiblock_partial_reduce<InDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
WorkspaceDesc_M_K,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_partial_reduce_multiblock<GridwiseReduce,
NeedIndices,
InDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
WorkspaceDesc_M_K,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
ws_desc_m_k,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.blkGroupSize,
arg.kBlockTileIterations,
arg.in_dev_,
arg.workspace_dev_,
arg.workspace_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(OutDstVectorSize != 1)
return (false);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// cases with small reduce_total_length should be handled by the BlockWise method
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
return (false);
return (true);
};
std::vector<int> GetWorkspace2dLengths(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
return (
std::vector<int>{static_cast<int>(pArg->invariant_total_length), pArg->blkGroupSize});
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceMultiBlockPartialReduce<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "device.hpp" #include "device.hpp"
#include "device_reduce.hpp" #include "device_reduce.hpp"
#include "device_reduce_common.hpp" #include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock.hpp"
#include "gridwise_2d_reduction_threadwise.hpp" #include "gridwise_2d_reduction_threadwise.hpp"
namespace ck { namespace ck {
...@@ -19,22 +20,19 @@ template <typename InDataType, ...@@ -19,22 +20,19 @@ template <typename InDataType,
index_t NumReduceDim, index_t NumReduceDim,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename OutElementwiseOperation, typename AccElementwiseOperation,
bool PropagateNan, bool PropagateNan,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutElementwiseOperation> struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
"Threadwise can only be called with KThreadClusterSize be 1 !");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
...@@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices; static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
...@@ -51,11 +49,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -51,11 +49,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<int>& inStrides) const std::vector<index_t>& inStrides)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths, static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<int>& outStrides) const std::vector<index_t>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
...@@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<int> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_indices_dev, IndexDataType* out_index_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths}, : outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
out_dev_{out_dev}, out_dev_{out_dev},
out_indices_dev_{out_indices_dev}, out_index_dev_{out_index_dev},
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
...@@ -183,36 +177,39 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -183,36 +177,39 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
reduce_lowest_length = inLengths_[Rank - 1]; reduce_lowest_length = inLengths_[Rank - 1];
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize; M_BlockTileSize;
} }
std::vector<int> inLengths_; std::vector<index_t> inLengths_;
std::vector<int> inStrides_; std::vector<index_t> inStrides_;
std::vector<int> outLengths_; std::vector<index_t> outLengths_;
std::vector<int> outStrides_; std::vector<index_t> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
IndexDataType* out_indices_dev_; IndexDataType* out_index_dev_;
InElementwiseOperation in_elementwise_op_; InElementwiseOperation in_elementwise_op_;
OutElementwiseOperation acc_elementwise_op_; AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length; index_t invariant_lowest_length;
int reduce_lowest_length; index_t reduce_lowest_length;
size_t invariant_total_length; long_index_t invariant_total_length;
size_t reduce_total_length; long_index_t reduce_total_length;
int numBlockTileIteration;
size_t gridSize; size_t gridSize;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto in_grid_desc_m_k = const auto in_grid_desc_m_k =
DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
...@@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
using InGridDesc_M_K = decltype(in_grid_desc_m_k); using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m); using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0; float avg_time = 0;
using GridwiseReduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
PropagateNan,
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<GridwiseReduce, const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
NeedIndices, OutputIndex,
HaveIndexInput,
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
...@@ -252,10 +249,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -252,10 +249,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
InGridDesc_M_K, InGridDesc_M_K,
OutGridDesc_M, OutGridDesc_M,
InElementwiseOperation, InElementwiseOperation,
OutElementwiseOperation>; AccElementwiseOperation>;
avg_time = launch_and_time_kernel(kernel, avg_time = launch_and_time_kernel(stream_config,
nrepeat, kernel,
dim3(arg.gridSize), dim3(arg.gridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
...@@ -265,16 +262,18 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -265,16 +262,18 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
arg.acc_elementwise_op_, arg.acc_elementwise_op_,
arg.alpha_, arg.alpha_,
arg.in_dev_, arg.in_dev_,
nullptr,
arg.beta_, arg.beta_,
arg.out_dev_, arg.out_dev_,
arg.out_indices_dev_); arg.out_index_dev_);
return (avg_time); return (avg_time);
}; };
float Run(const BaseArgument* p_arg, int nrepeat = 1) override float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}; };
}; };
...@@ -310,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -310,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
if(pArg->invariant_lowest_length % OutDstVectorSize != 0) if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false); return (false);
// TODO: remove this. Should return true, as long as this DeviceOP instance support this // cases with big reduce_total_length should be handled by Blockwise kernel
// case for bigger reduce_total_length size, we are supposed to use BlockWise method for
// better performance
if(pArg->reduce_total_length / KThreadSliceSize >= 32) if(pArg->reduce_total_length / KThreadSliceSize >= 32)
return (false); return (false);
...@@ -320,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -320,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_index_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op) override const AccElementwiseOperation acc_elementwise_op) override
{ {
(void)in_index_dev;
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
outLengths, outLengths,
...@@ -343,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -343,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
beta, beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev), static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev), static_cast<IndexDataType*>(out_index_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op, in_elementwise_op,
acc_elementwise_op); acc_elementwise_op);
}; };
...@@ -359,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -359,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceReducceThreadWise<" << BlockSize << ","; str << "DeviceReduceThreadWise<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on // clang-format on
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#pragma once
#include "data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace binary_element_wise {
template <typename Y, typename X1, typename X2>
struct Add;
template <>
struct Add<double, double, double>
{
__host__ __device__ constexpr void
operator()(double& dst, const double& src1, const double& src2) const
{
dst = src1 + src2;
}
};
template <>
struct Add<float, float, float>
{
__host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const
{
dst = src1 + src2;
}
};
template <>
struct Add<half_t, half_t, half_t>
{
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
{
dst = src1 + src2;
}
};
template <>
struct Add<bhalf_t, bhalf_t, bhalf_t>
{
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
{
const float x1 = ck::type_convert<float>(src1);
const float x2 = ck::type_convert<float>(src2);
const float y = x1 + x2;
dst = ck::type_convert<bhalf_t>(y);
}
};
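These binary element-wise functors are plain callable structs; the bhalf_t specialization deliberately converts to float, adds, and converts back rather than computing in bf16. A minimal standalone usage sketch (using float in place of ck::half_t/bhalf_t):

// Standalone sketch (not part of the diff): how such a binary functor is applied.
#include <cstdio>

template <typename Y, typename X1, typename X2>
struct Add
{
    void operator()(Y& y, const X1& x1, const X2& x2) const { y = x1 + x2; }
};

int main()
{
    float y = 0.0f;
    Add<float, float, float>{}(y, 1.5f, 2.25f);
    std::printf("%f\n", y); // 3.750000
    return 0;
}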
template <typename Y, typename X1, typename X2>
struct Substract;
template <>
struct Substract<double, double, double>
{
__host__ __device__ constexpr void
operator()(double& dst, const double& src1, const double& src2) const
{
dst = src1 - src2;
}
};
template <>
struct Substract<float, float, float>
{
__host__ __device__ constexpr void
operator()(float& dst, const float& src1, const float& src2) const
{
dst = src1 - src2;
}
};
template <>
struct Substract<half_t, half_t, half_t>
{
__host__ __device__ constexpr void
operator()(half_t& dst, const half_t& src1, const half_t& src2) const
{
dst = src1 - src2;
}
};
template <>
struct Substract<bhalf_t, bhalf_t, bhalf_t>
{
__host__ __device__ constexpr void
operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const
{
const float x1 = ck::type_convert<float>(src1);
const float x2 = ck::type_convert<float>(src2);
const float y = x1 - x2;
dst = ck::type_convert<bhalf_t>(y);
}
};
} // namespace binary_element_wise
} // namespace tensor_operation
} // namespace ck
#pragma once #pragma once
#include "data_type.hpp" #include "data_type.hpp"
#include "math_v2.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -143,35 +144,22 @@ struct AddHardswishAdd ...@@ -143,35 +144,22 @@ struct AddHardswishAdd
} }
}; };
struct RequantReluRequant struct Normalize
{ {
// FIXME: We just need one scale for Relu / Leaky Relu / PRelu Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
RequantReluRequant(float scaleGemm, float scaleRelu)
: scaleGemm_(scaleGemm), scaleRelu_(scaleRelu) __host__ __device__ constexpr void operator()(float& y,
const float& x,
const float& mean,
const float& mean_square,
const float& gamma,
const float& beta) const
{ {
float variance = mean_square - (mean * mean);
y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta;
} }
__host__ __device__ constexpr void operator()(int8_t& y, const int& x) const float epsilon_;
{
float gemm_requant = scaleGemm_ * static_cast<float>(x);
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<int8_t>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
// for reference_gemm
__host__ __device__ constexpr void operator()(float& y, const float& x) const
{
float gemm_requant = scaleGemm_ * x;
float relu = gemm_requant > 0 ? gemm_requant : 0;
float relu_requant = scaleRelu_ * relu;
y = static_cast<float>(relu_requant > 127 ? 127
: relu_requant < -128 ? -128 : relu_requant);
}
float scaleGemm_;
float scaleRelu_;
}; };
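The new Normalize functor derives the variance from the running mean and mean-of-squares (Var[x] = E[x^2] - E[x]^2) and then applies the usual (x - mean) / sqrt(var + eps) * gamma + beta. A standalone numeric sketch with made-up statistics:

// Standalone sketch (not part of the diff): the arithmetic performed by Normalize above.
#include <cmath>
#include <cstdio>

int main()
{
    const float x = 2.0f, mean = 1.0f, mean_square = 2.5f; // E[x] and E[x^2] (assumed)
    const float gamma = 1.0f, beta = 0.0f, epsilon = 1e-4f;

    const float variance = mean_square - mean * mean; // 1.5
    const float y = (x - mean) / std::sqrt(variance + epsilon) * gamma + beta;

    std::printf("variance=%f y=%f\n", variance, y); // y is roughly 0.816
    return 0;
}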
// Unary operators are usually called element-wisely before/after the reduction is executed on the // Unary operators are usually called element-wisely before/after the reduction is executed on the
...@@ -309,7 +297,7 @@ struct UnaryAbs<float, float> ...@@ -309,7 +297,7 @@ struct UnaryAbs<float, float>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); }; __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
}; };
template <> template <>
...@@ -317,7 +305,7 @@ struct UnaryAbs<half_t, half_t> ...@@ -317,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); }; __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
}; };
template <> template <>
...@@ -325,7 +313,7 @@ struct UnaryAbs<double, double> ...@@ -325,7 +313,7 @@ struct UnaryAbs<double, double>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); }; __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
}; };
template <> template <>
...@@ -333,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t> ...@@ -333,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
{
int8_t sgn = x >> (8 - 1);
y = (x ^ sgn) - sgn;
};
}; };
template <typename Y, typename X> template <typename Y, typename X>
...@@ -349,7 +332,7 @@ struct UnarySqrt<float, float> ...@@ -349,7 +332,7 @@ struct UnarySqrt<float, float>
{ {
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); }; __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
}; };
template <> template <>
...@@ -357,7 +340,10 @@ struct UnarySqrt<double, double> ...@@ -357,7 +340,10 @@ struct UnarySqrt<double, double>
{ {
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); }; __host__ __device__ void operator()(double& y, const double& x) const
{
y = ck::math::sqrt(x);
};
}; };
} // namespace element_wise } // namespace element_wise
......
...@@ -5,20 +5,6 @@ namespace ck { ...@@ -5,20 +5,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
struct ReduceSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v; }
};
struct ReduceSquareSum
{
__host__ __device__ static constexpr float GetReduceZeroValue() { return float(0); }
__host__ __device__ void Reduce(float& acc, float v) const { acc += v * v; }
};
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
#ifndef UTILITY_BLOCK_TO_CTILE_MAP
#define UTILITY_BLOCK_TO_CTILE_MAP
#include "utility/math.hpp"
#include "utility/number.hpp"
#include "tensor_description/tensor_adaptor.hpp"
#include "tensor_description/multi_index_transform_helper.hpp"
namespace ck {
// Rows of column-vectors
template <index_t MPerBlock,
index_t NPerBlock,
typename CGridDesc_M_N,
bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N0_M01
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 1)
: M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01_);
const index_t grid_size = M00 * M01_ * N0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return underlying_map_.CalculateBottomIndex(idx_top);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
if constexpr(DeviceCTileIndexCheck)
return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
else
return true;
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
if(M0 % M01_ == 0)
{
return true;
}
else
{
return false;
}
}
private:
__host__ __device__ static constexpr auto
GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01)
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01);
const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1),
make_unmerge_transform(make_tuple(M00, M01)),
make_pass_through_transform(make_tuple(N0))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_n0_m01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
index_t M01_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1));
UnderlyingMap underlying_map_;
};
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const index_t grid_size = M0 * N0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_;
};
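BlockToCTileMap_M00_N0_M01Adapt groups M01 rows of C tiles per column sweep and shrinks the group for the final partial rows, so every linear block id lands on an in-range (m, n) tile without padding the grid. A standalone sketch of that index arithmetic for small, made-up tile counts:

// Standalone sketch (not part of the diff): the index arithmetic of
// BlockToCTileMap_M00_N0_M01Adapt::CalculateBottomIndex for illustrative sizes.
#include <cstdio>

int main()
{
    const int M0 = 10, N0 = 4, M01 = 8; // tile counts and grouping factor (assumed)

    for(int block_1d_id = 0; block_1d_id < M0 * N0; ++block_1d_id)
    {
        const int idx_N0 = block_1d_id % N0;
        const int idx_M0 = block_1d_id / N0;

        // the last partial group of rows uses a smaller M01 so every tile stays in range
        const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;

        const int idx_M00          = idx_M0 / M01;
        const int idx_M01          = idx_M0 % M01;
        const int idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        const int m = idx_N0_M01_local % M01_adapt + idx_M00 * M01;
        const int n = idx_N0_M01_local / M01_adapt;

        std::printf("block %2d -> tile (m=%d, n=%d)\n", block_1d_id, m, n);
    }
    return 0;
}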
// 2D slices of column-vectors in 3D space
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8,
index_t KSplit = 1)
: M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n)
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const index_t grid_size = M0 * N0 * KSplit_;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);
const index_t idx_ksplit = block_1d_id / (M0 * N0);
block_1d_id = block_1d_id % (M0 * N0);
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_ksplit,
idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that the user gets the grid size from CalculateGridSize()
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private:
index_t M01_;
index_t KSplit_;
CGridDesc_M_N c_grid_desc_m_n_;
};
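// Usage sketch (illustrative only; parameter values are hypothetical): with KSplit > 1
// the launch grid is KSplit times larger and CalculateBottomIndex() returns a
// 3-component index (k-split id, m-tile, n-tile), e.g.
//
//   BlockToCTileMap_KSplit_M00_N0_M01Adapt<128, 128, decltype(c_grid_desc_m_n)> map{
//       c_grid_desc_m_n, /* M01 = */ 8, /* KSplit = */ 4};
//   const auto idx  = map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
//   const auto k_id = idx[Number<0>{}]; // which K partition this block accumulates
//   const auto im   = idx[Number<1>{}]; // C-tile row
//   const auto in   = idx[Number<2>{}]; // C-tile column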
// Blocks of row-vectors
template <index_t MPerBlock,
index_t NPerBlock,
typename CGridDesc_M_N,
bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_M00_N00_M01_N01
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default;
__host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 1,
index_t N01 = 1)
: M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01))
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01_);
const auto N00 = math::integer_divide_ceil(N0, N01_);
const index_t grid_size = M00 * M01_ * N00 * N01_;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return underlying_map_.CalculateBottomIndex(idx_top);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
if constexpr(DeviceCTileIndexCheck)
return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
else
return true;
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
if(M0 % M01_ == 0 && N0 % N01_ == 0)
{
return true;
}
else
{
return false;
}
}
private:
__host__ __device__ static constexpr auto
GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01);
const auto N00 = math::integer_divide_ceil(N0, N01);
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
index_t M01_, N01_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1));
UnderlyingMap underlying_map_;
};
// 2D slices of row-vectors in 3D space
template <index_t MPerBlock,
index_t NPerBlock,
typename CGridDesc_M_N,
bool DeviceCTileIndexCheck = false>
struct BlockToCTileMap_KSplit_M00_N00_M01_N01
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default;
__host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 1,
index_t N01 = 1,
index_t KSplit = 1)
: M01_(M01),
N01_(N01),
KSplit_(KSplit),
underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
{
}
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01_);
const auto N00 = math::integer_divide_ceil(N0, N01_);
const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return underlying_map_.CalculateBottomIndex(idx_top);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
if constexpr(DeviceCTileIndexCheck)
return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
else
return true;
}
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
if(M0 % M01_ == 0 && N0 % N01_ == 0)
{
return true;
}
else
{
return false;
}
}
private:
__host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01,
index_t N01,
index_t KSplit)
{
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
const auto M00 = math::integer_divide_ceil(M0, M01);
const auto N00 = math::integer_divide_ceil(N0, N01);
const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(KSplit),
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor);
return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
}
index_t M01_, N01_, KSplit_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
UnderlyingMap underlying_map_;
};
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim)
{
bool is_valid = false;
const index_t m_block = c_tile_dim[Number<0>{}];
const index_t n_block = c_tile_dim[Number<1>{}];
if constexpr(CTileIdx::Size() == 2)
{
const index_t m_block_idx = c_tile_idx[Number<0>{}];
const index_t n_block_idx = c_tile_idx[Number<1>{}];
if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
{
is_valid = true;
}
}
else if constexpr(CTileIdx::Size() == 3)
{
const index_t ksplit_idx = c_tile_idx[Number<0>{}];
const index_t m_block_idx = c_tile_idx[Number<1>{}];
const index_t n_block_idx = c_tile_idx[Number<2>{}];
if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block)
{
is_valid = true;
}
ignore = ksplit_idx;
}
return is_valid;
}
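// Typical device-side guard (illustrative sketch, assuming a map constructed with
// DeviceCTileIndexCheck = true): the launch grid may be rounded up, so a kernel drops
// out-of-range blocks before doing any work, e.g.
//
//   const auto block_work_idx =
//       block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
//   if(!block_2_ctile_map.ValidCTileIndex(
//          block_work_idx, make_tuple(M0 /* #m-tiles */, N0 /* #n-tiles */)))
//   {
//       return; // this block has no C tile to compute
//   }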
} // namespace ck
#endif // UTILITY_BLOCK_TO_CTILE_MAP
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#define CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
bool NeedIndices,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation>
__global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
constexpr bool IsSecondCall = false;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
};
};
template <typename GridwiseReduction,
bool NeedIndices,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation>
__global__ void
kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
constexpr bool IsSecondCall = true;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
GridwiseReduction::RunSecondCallWithIndex(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
};
};
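// Work-partitioning sketch (illustrative, hypothetical numbers): each block owns one band of
// M_BlockTileSize = MThreadClusterSize * MThreadSliceSize consecutive rows and sweeps the whole
// K dimension in steps of K_BlockTileSize = KThreadClusterSize * KThreadSliceSize. For example,
// with BlockSize = 256, MThreadClusterSize = 64, KThreadClusterSize = 4, MThreadSliceSize = 1
// and KThreadSliceSize = 8, a block covers 64 rows and 32 K-elements per step, so K = 4096
// takes ceil(4096 / 32) = 128 iterations of the tile loop in Run().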
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
bool PropagateNan,
bool BetaIsZero,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_blockwise
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
template <bool IsSecondCall>
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(IsSecondCall)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
};
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)p_indices_global;
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
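// epilogue: out = acc_elementwise_op(block-reduced value) * alpha (+ beta * prior output
// further below); only the threads with thread_k_cluster_id == 0 hold the final per-row
// results and write them out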
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
}
};
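// RunWithIndex: same blockwise reduction, but every candidate value carries its position
// along the reduced K dimension, so index-returning operations (e.g. argmax/argmin-style
// reductions) can emit both the reduced value and the index where it occurred.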
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
(void)p_ws_indices_global;
// LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
index_t indexOffset = 0;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
});
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
do
{
// load the thread slice
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices of the values this thread is about to reduce
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
// for an indexed operation, acc_elementwise_op should do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0),
accu_index_buf,
out_grid_desc_m,
out_global_idx_buf);
}
};
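// RunSecondCallWithIndex: second stage of a two-stage (multi-block) indexed reduction; it
// consumes the per-block partial values and indices written to the workspace by the first
// stage and reduces them to the final per-row value/index results.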
__device__ static void
RunSecondCallWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_ws_values_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
(void)in_elementwise_op;
// LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_ws_values_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
});
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
src_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
src_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
// for an indexed operation, acc_elementwise_op should do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0),
accu_index_buf,
out_grid_desc_m,
out_global_idx_buf);
}
};
};
} // namespace ck
#endif
...@@ -23,75 +23,86 @@ ...@@ -23,75 +23,86 @@
* SOFTWARE. * SOFTWARE.
* *
*******************************************************************************/ *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP #ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP #define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp" #include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
namespace ck { namespace ck {
template <typename GridwiseReduction, template <typename GridwiseReduction,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInput,
typename InDataType, typename InDataType,
typename OutDataType,
typename AccDataType, typename AccDataType,
typename IndexDataType, typename IndexDataType,
typename InGridDesc_M_K, typename InGridDesc_M_K,
typename WorkspaceDesc_M_K, typename OutGridDesc_M,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation> typename AccElementwiseOperation>
__global__ void __global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
kernel_partial_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m,
const WorkspaceDesc_M_K workspace_desc_m_k, const InElementwiseOperation in_elementwise_op,
const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op,
const AccElementwiseOperation acc_elementwise_op, index_t block_group_size,
index_t block_group_size, index_t num_k_block_tile_iteration,
index_t num_k_block_tile_iteration, AccDataType alpha,
const InDataType* const __restrict__ p_src_global, const InDataType* const __restrict__ p_in_value_global,
AccDataType* const __restrict__ p_ws_values_global, const IndexDataType* const __restrict__ p_in_index_global,
IndexDataType* const __restrict__ p_ws_indices_global) AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{ {
if constexpr(!NeedIndices) if constexpr(!OutputIndex)
{ {
(void)p_in_index_global;
(void)p_out_index_global;
GridwiseReduction::Run(in_grid_desc_m_k, GridwiseReduction::Run(in_grid_desc_m_k,
workspace_desc_m_k, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
block_group_size, block_group_size,
num_k_block_tile_iteration, num_k_block_tile_iteration,
p_src_global, alpha,
p_ws_values_global, p_in_value_global,
p_ws_indices_global); beta,
p_out_value_global);
} }
else else
{ {
GridwiseReduction::RunWithIndex(in_grid_desc_m_k, GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
workspace_desc_m_k, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
block_group_size, num_k_block_tile_iteration,
num_k_block_tile_iteration, alpha,
p_src_global, p_in_value_global,
p_ws_values_global, p_in_index_global,
p_ws_indices_global); beta,
p_out_value_global,
p_out_index_global);
}; };
}; };
template <typename InDataType, template <typename InDataType,
typename OutDataType,
typename AccDataType, typename AccDataType,
typename IndexDataType, typename IndexDataType,
typename InGridDesc_M_K, typename InGridDesc_M_K,
typename WorkspaceDesc_M_K, typename OutGridDesc_M,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan, bool PropagateNan,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
...@@ -101,14 +112,13 @@ template <typename InDataType, ...@@ -101,14 +112,13 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce struct GridwiseReduction_mk_to_m_multiblock
{ {
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0), (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>; using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
...@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}))); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough; using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const WorkspaceDesc_M_K& workspace_desc_m_k, const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size, index_t block_group_size,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
const InDataType* const __restrict__ p_src_global, AccDataType alpha,
AccDataType* const __restrict__ p_ws_values_global, const InDataType* const __restrict__ p_in_value_global,
IndexDataType* const __restrict__ p_ws_indices_global) AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{ {
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType, const auto identityVal = ReduceOperation::GetIdentityValue();
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf = auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
...@@ -181,7 +191,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -181,7 +191,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
do do
{ {
threadwise_src_load.Run(in_grid_desc_m_k, threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf, in_global_val_buf,
thread_buffer_desc, thread_buffer_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
...@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
reducedTiles++; reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration); } while(reducedTiles < num_k_block_tile_iteration);
// Each block executes multiple parallel reductions on the LDS, and due to the using of constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// vector_load, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}( static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed( static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})); if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
auto threadwise_workspace_store = if(block_group_size == 0 && !float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
AccDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, OutGridDesc_M,
PassThroughOp, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize>,
Sequence<0, 1>, Sequence<0>,
1, 0,
1, OutDstVectorSize,
InMemoryDataOperationEnum::Set, OutMemoryDataOperation,
1, 1,
true>( true>(
workspace_desc_m_k, out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize),
block_local_id),
PassThroughOp{}); PassThroughOp{});
threadwise_workspace_store.Run(reduced_data_desc, threadwise_dst_store.Run(reduced_data_desc,
make_tuple(I0, I0), make_tuple(I0),
accu_value_buf, accu_value_buf,
workspace_desc_m_k, out_grid_desc_m,
workspace_global_buf); out_global_val_buf);
} }
}; };
template <bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const WorkspaceDesc_M_K& workspace_desc_m_k, const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
const InDataType* const __restrict__ p_src_global, AccDataType alpha,
AccDataType* const __restrict__ p_ws_values_global, const InDataType* const __restrict__ p_in_value_global,
IndexDataType* const __restrict__ p_ws_indices_global) const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{ {
using BlockwiseReduceWithIndex = using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType, PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType, IndexDataType,
BlockSize, BlockSize,
ThreadClusterLengths_M_K, Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder, ThreadClusterArrangeOrder,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
...@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType, AccDataType,
IndexDataType>; IndexDataType>;
(void)acc_elementwise_op; (void)in_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ index_t p_reduce_work_idx_buffer[BlockSize]; __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto in_global_buf = const auto identityVal = ReduceOperation::GetIdentityValue();
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf = auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
...@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf; in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType, IndexDataType,
MThreadSliceSize * KThreadSliceSize, MThreadSliceSize * KThreadSliceSize,
...@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf; StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_1d_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx = const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
...@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1]; const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>; using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1,
                                             false>(
                in_grid_desc_m_k,
                make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = identityVal;
            accu_index_buf(I) = 0;
        });

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        index_t reducedTiles = 0;

        if constexpr(HaveIndexInput)
        {
            auto threadwise_src_idx_load =
                ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                                 IndexDataType,
                                                 InGridDesc_M_K,
                                                 decltype(thread_buffer_desc),
                                                 ThreadBufferLengths,
                                                 ThreadBufferDimAccessOrder,
                                                 InSrcVectorDim,
                                                 InSrcVectorSize,
                                                 1,
                                                 false>(
                    in_grid_desc_m_k,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * KThreadSliceSize));

            do
            {
                // load the thread slice
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                threadwise_src_idx_load.Run(in_grid_desc_m_k,
                                            in_global_idx_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_idx_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    AccDataType tmpValue   = identityVal;
                    IndexDataType tmpIndex = 0;

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        AccumulationWithIndex::Calculate(tmpValue,
                                                         in_thread_val_buf[Number<offset>{}],
                                                         tmpIndex,
                                                         in_thread_idx_buf[Number<offset>{}]);
                    });

                    BlockwiseReduceWithIndex::Reduce(
                        reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);

                    AccumulationWithIndex::Calculate(
                        accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
                });

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
                threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
        else
        {
            index_t indexOffset = 0;

            do
            {
                // load the thread slice
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        // initialize the indices for the per-thread to-reduce values
                        in_thread_idx_buf(Number<offset>{}) =
                            indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();

                        // do element-wise pre-reduction operation
                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });

                    AccDataType tmpValue   = identityVal;
                    IndexDataType tmpIndex = 0;

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        AccumulationWithIndex::Calculate(tmpValue,
                                                         in_thread_val_buf[Number<offset>{}],
                                                         tmpIndex,
                                                         in_thread_idx_buf[Number<offset>{}]);
                    });

                    BlockwiseReduceWithIndex::Reduce(
                        reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);

                    AccumulationWithIndex::Calculate(
                        accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
                });

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexOffset += K_BlockTileSize;
                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        };

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                // for indexed operations, acc_elementwise_op should do nothing
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if(!float_equal_zero{}(beta))
            {
                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                     OutDataType,
                                                     OutGridDesc_M,
                                                     decltype(reduced_data_desc),
                                                     Sequence<MThreadSliceSize>,
                                                     Sequence<0>,
                                                     0,
                                                     OutDstVectorSize,
                                                     1,
                                                     true>(
                        out_grid_desc_m,
                        make_multi_index(block_global_1d_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize));

                threadwise_dst_load.Run(out_grid_desc_m,
                                        out_global_val_buf,
                                        reduced_data_desc,
                                        make_tuple(I0),
                                        priorDstValueBuf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
                });
            };

            auto threadwise_dst_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp,
                                                   Sequence<MThreadSliceSize>,
                                                   Sequence<0>,
                                                   0,
                                                   OutDstVectorSize,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>(
                    out_grid_desc_m,
                    make_multi_index(block_global_1d_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_dst_val_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_value_buf,
                                         out_grid_desc_m,
                                         out_global_val_buf);

            threadwise_dst_idx_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_index_buf,
                                         out_grid_desc_m,
                                         out_global_idx_buf);
        }
    };
};
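The indexed path above carries a (value, index) pair per output element: AccumulationWithIndex::Calculate folds one candidate into the running pair, and BlockwiseReduceWithIndex then combines the per-thread pairs across the block. Below is a minimal host-side sketch of the same two-level pattern, assuming a max-reduction and a hypothetical accumulate_with_index helper; none of this is the CK API.

#include <cstdio>
#include <limits>
#include <vector>

// Hypothetical host-side analogue of AccumulationWithIndex::Calculate for a
// max-reduction: keep the larger value and remember where it came from.
static void accumulate_with_index(float& accu_val, int& accu_idx, float val, int idx)
{
    if(val > accu_val)
    {
        accu_val = val;
        accu_idx = idx;
    }
}

int main()
{
    // One "row" (invariant dimension) of length K, reduced tile by tile,
    // mimicking the do/while loop over num_k_block_tile_iteration above.
    const std::vector<float> in = {0.5f, 2.0f, -1.0f, 3.5f, 0.25f, 3.5f};
    const int tile_size         = 2;

    float accu_val = std::numeric_limits<float>::lowest(); // identity of a max-reduction
    int accu_idx   = 0;

    for(int tile_begin = 0; tile_begin < static_cast<int>(in.size()); tile_begin += tile_size)
    {
        // per-thread-slice accumulation within the tile
        float tmp_val = std::numeric_limits<float>::lowest();
        int tmp_idx   = 0;
        for(int k = 0; k < tile_size; ++k)
            accumulate_with_index(tmp_val, tmp_idx, in[tile_begin + k], tile_begin + k);

        // merge the tile-local pair into the running pair (blockwise step analogue)
        accumulate_with_index(accu_val, accu_idx, tmp_val, tmp_idx);
    }

    std::printf("max = %f at index %d\n", accu_val, accu_idx); // max = 3.5 at index 3
    return 0;
}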
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename AccElementwiseOperation>
__global__ void
kernel_reduce_multiblock_atocmi_add(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
block_group_size,
num_k_block_tile_iteration,
alpha,
p_in_global,
p_out_global);
};
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
        // Each block executes multiple parallel reductions on the LDS and atomically adds its
        // reduced output to the global location corresponding to each invariant dimension, so
        // that a consistent reduced result is obtained for that invariant dimension. Because
        // vector loads are used, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::AtomicAdd,
1,
true>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
}
};
};
} // namespace ck
#endif
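The struct above implements the multiblock atomic-add strategy: the reduce (K) dimension is split across block_group_size blocks, each block reduces only its own K slice, scales the partial result by alpha, and atomically adds it into an output that was zero-initialized beforehand, so no workspace and no second kernel pass are needed. A minimal host-side sketch of that scheme follows, using std::thread in place of GPU blocks and a plain sum reduction; the names and sizes are illustrative only and not part of CK.

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Hypothetical host-side analogue of an atomic add on a floating-point value,
// implemented with a compare-exchange loop.
static void atomic_add(std::atomic<double>& target, double value)
{
    double expected = target.load();
    while(!target.compare_exchange_weak(expected, expected + value))
    {
        // expected is refreshed by compare_exchange_weak on failure
    }
}

int main()
{
    const int K                = 16;
    const int block_group_size = 4; // number of "blocks" cooperating on one output element
    const int k_per_block      = K / block_group_size;
    const double alpha         = 1.0;

    std::vector<double> in(K);
    for(int k = 0; k < K; ++k)
        in[k] = 1.0 + k; // sum of 1..16 is 136

    std::atomic<double> out(0.0); // the output must be pre-zeroed, as the device op requires

    std::vector<std::thread> blocks;
    for(int b = 0; b < block_group_size; ++b)
    {
        blocks.emplace_back([&, b] {
            double partial = 0.0; // per-block reduction over its K slice
            for(int k = b * k_per_block; k < (b + 1) * k_per_block; ++k)
                partial += in[k];
            atomic_add(out, alpha * partial); // InMemoryDataOperationEnum::AtomicAdd analogue
        });
    }
    for(auto& t : blocks)
        t.join();

    std::printf("reduced sum = %f\n", out.load()); // 136
    return 0;
}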
...@@ -37,7 +37,8 @@
namespace ck {

template <typename GridwiseReduction,
          bool OutputIndex,
          bool HaveIndexInput,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
...@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                         const InElementwiseOperation in_elementwise_op,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_value_global,
                                         const IndexDataType* const __restrict__ p_in_index_global,
                                         AccDataType beta,
                                         OutDataType* const __restrict__ p_out_value_global,
                                         IndexDataType* const __restrict__ p_out_index_global)
{
    if constexpr(!OutputIndex)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               alpha,
                               p_in_value_global,
                               beta,
                               p_out_value_global);
    }
    else
    {
        GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
                                                                 out_grid_desc_m,
                                                                 in_elementwise_op,
                                                                 acc_elementwise_op,
                                                                 alpha,
                                                                 p_in_value_global,
                                                                 p_in_index_global,
                                                                 beta,
                                                                 p_out_value_global,
                                                                 p_out_index_global);
    };
};

...@@ -91,11 +93,9 @@ template <typename InDataType,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
...@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
        using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                     ThreadReduceSrcDesc_M_K,
...@@ -136,21 +135,21 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                     ReduceOperation,
                                                     PropagateNan>;

        const auto identityVal = ReduceOperation::GetIdentityValue();

        const auto in_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                          in_grid_desc_m_k.GetElementSpaceSize(),
                                                          type_convert<InDataType>(identityVal));

        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

...@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise
        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1,
                                             false>(
                in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t reducedLength = 0;
        do
        {
            threadwise_src_val_load.Run(in_grid_desc_m_k,
                                        in_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                // do element-wise pre-reduction operation
...@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise
            ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);

            threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);

...@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise
        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        if(!float_equal_zero{}(beta))
        {
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                                        OutDataType,
                                                                        OutGridDesc_M,
                                                                        decltype(reduced_data_desc),
                                                                        Sequence<MThreadSliceSize>,
                                                                        Sequence<0>,
                                                                        0,
                                                                        1,
                                                                        1,
                                                                        true>(
                out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                priorDstValue_buf;

            threadwise_dst_load.Run(out_grid_desc_m,
                                    dst_global_buf,
                                    reduced_data_desc,
                                    make_tuple(I0),
                                    priorDstValue_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
            });
        };

        auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                                       OutDataType,
                                                                       decltype(reduced_data_desc),
                                                                       OutGridDesc_M,
                                                                       PassThroughOp,
                                                                       Sequence<MThreadSliceSize>,
                                                                       Sequence<0>,
                                                                       0,
                                                                       OutDstVectorSize,
                                                                       OutMemoryDataOperation,
                                                                       1,
                                                                       false>(
            out_grid_desc_m,
            make_multi_index(thread_global_1d_id * MThreadSliceSize),
            PassThroughOp{});

        threadwise_dst_store.Run(
            reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
    };

    template <bool HaveIndexInput>
    __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                        const OutGridDesc_M& out_grid_desc_m,
                                        const InElementwiseOperation& in_elementwise_op,
                                        const AccElementwiseOperation& acc_elementwise_op,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_value_global,
                                        const IndexDataType* const __restrict__ p_in_index_global,
                                        AccDataType beta,
                                        OutDataType* const __restrict__ p_out_value_global,
                                        IndexDataType* const __restrict__ p_out_index_global)
    {
        using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
                                                                       IndexDataType,
...@@ -279,14 +276,19 @@ struct GridwiseReduction_mk_to_m_threadwise
        (void)acc_elementwise_op;

        const auto identityVal = ReduceOperation::GetIdentityValue();

        const auto in_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                          in_grid_desc_m_k.GetElementSpaceSize(),
                                                          type_convert<InDataType>(identityVal));

        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_index_global, out_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_val_buf;

...@@ -301,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = identityVal;
            accu_index_buf(I) = 0;
        });

...@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise
        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1,
                                             false>(
                in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t indexStart    = 0;
        index_t reducedLength = 0;

        if constexpr(HaveIndexInput)
        {
            auto threadwise_src_idx_load =
                ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                                 IndexDataType,
                                                 InGridDesc_M_K,
                                                 decltype(thread_buffer_desc),
                                                 ThreadBufferLengths,
                                                 ThreadBufferDimAccessOrder,
                                                 InSrcVectorDim,
                                                 InSrcVectorSize,
                                                 1,
                                                 false>(
                    in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                threadwise_src_idx_load.Run(in_grid_desc_m_k,
                                            in_global_idx_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_idx_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduceWithIndex::Reduce(
                    in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
                threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexStart += KThreadSliceSize;
                reducedLength += KThreadSliceSize;
            } while(reducedLength < toReduceLength);
        }
        else
        {
            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduceWithIndex::Reduce(
                    in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexStart += KThreadSliceSize;
                reducedLength += KThreadSliceSize;
            } while(reducedLength < toReduceLength);
        };

        // for indexed operations, acc_elementwise_op should do nothing
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise
        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        if(!float_equal_zero{}(beta))
        {
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                                        OutDataType,
                                                                        OutGridDesc_M,
                                                                        decltype(reduced_data_desc),
                                                                        Sequence<MThreadSliceSize>,
                                                                        Sequence<0>,
                                                                        0,
                                                                        1,
                                                                        1,
                                                                        false>(
                out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                priorDstValue_buf;

            threadwise_dst_load.Run(out_grid_desc_m,
                                    out_global_val_buf,
                                    reduced_data_desc,
                                    make_tuple(I0),
                                    priorDstValue_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
            });
        };

        auto threadwise_dst_val_store =
...@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                               Sequence<0>,
                                               0,
                                               OutDstVectorSize,
                                               OutMemoryDataOperation,
                                               1,
                                               false>(
                out_grid_desc_m,
...@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise
                                               Sequence<0>,
                                               0,
                                               OutDstVectorSize,
                                               OutMemoryDataOperation,
                                               1,
                                               false>(
                out_grid_desc_m,
......
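Per output element, the threadwise kernel above computes alpha times the reduced value and, when beta is non-zero, blends in the prior destination value; the OutMemoryDataOperation template parameter then decides whether the final write is a plain Set or an AtomicAdd. The following is a minimal host-side sketch of that per-element contract, assuming an Add reduction and pass-through elementwise operations; the helper is illustrative and not part of CK.

#include <cstdio>
#include <vector>

// Sketch of what one output element of the threadwise reduction computes:
//   out = alpha * reduce_k(in_op(in[m][k])) + beta * prior_out
// Here the reduce op is Add and in_op/acc_op are taken as pass-through.
static float reduce_row(const std::vector<float>& row, float alpha, float beta, float prior_out)
{
    float acc = 0.0f; // identity value of an Add reduction
    for(float v : row)
        acc += v;

    // beta == 0 skips the read of the prior destination, matching the
    // float_equal_zero{} check in the kernel
    return beta == 0.0f ? alpha * acc : alpha * acc + beta * prior_out;
}

int main()
{
    const std::vector<float> row = {1.0f, 2.0f, 3.0f, 4.0f};
    std::printf("%f\n", reduce_row(row, 0.5f, 0.0f, 10.0f)); // 5.0
    std::printf("%f\n", reduce_row(row, 0.5f, 2.0f, 10.0f)); // 25.0
    return 0;
}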
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename Gridwise5AryEltwise,
typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor>
__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
Gridwise5AryEltwise::Run(p_a_global,
p_b_global,
p_c_global,
p_d_global,
p_e_global,
p_f_global,
a_grid_desc_m,
b_grid_desc_m,
c_grid_desc_m,
d_grid_desc_m,
e_grid_desc_m,
f_grid_desc_m,
functor);
}
// TODO - implement n-ary Elementwise_1D, tuple of inputs and tuple of outputs
template <typename ADataType,
typename BDataType,
typename CDataType,
typename DDataType,
typename EDataType,
typename FDataType,
typename ComputeDataType,
typename AGridDesc_M,
typename BGridDesc_M,
typename CGridDesc_M,
typename DGridDesc_M,
typename EGridDesc_M,
typename FGridDesc_M,
typename ElementwiseFunctor,
index_t MPerThread,
index_t AScalarPerVector,
index_t BScalarPerVector,
index_t CScalarPerVector,
index_t DScalarPerVector,
index_t EScalarPerVector,
index_t FScalarPerVector>
struct Gridwise5AryElementwise_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThrough = tensor_operation::element_wise::PassThrough;
static __device__ auto CalculateElementwiseIndex()
{
const index_t global_thread_id = get_thread_global_1d_id();
return make_multi_index(global_thread_id * MPerThread);
}
__device__ static void Run(const ADataType* __restrict__ p_a_global,
const BDataType* __restrict__ p_b_global,
const CDataType* __restrict__ p_c_global,
const DDataType* __restrict__ p_d_global,
const EDataType* __restrict__ p_e_global,
FDataType* __restrict__ p_f_global,
const AGridDesc_M a_grid_desc_m,
const BGridDesc_M b_grid_desc_m,
const CGridDesc_M c_grid_desc_m,
const DGridDesc_M d_grid_desc_m,
const EGridDesc_M e_grid_desc_m,
const FGridDesc_M f_grid_desc_m,
const ElementwiseFunctor functor)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_grid_desc_m.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_grid_desc_m.GetElementSpaceSize());
const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_grid_desc_m.GetElementSpaceSize());
const auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d_global, d_grid_desc_m.GetElementSpaceSize());
const auto e_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_global, e_grid_desc_m.GetElementSpaceSize());
auto f_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_f_global, f_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> d_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> e_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> f_thread_buf;
const auto thread_store_global_offset = CalculateElementwiseIndex();
auto a_global_load =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ComputeDataType,
AGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
AScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{a_grid_desc_m, thread_store_global_offset};
auto b_global_load =
ThreadwiseTensorSliceTransfer_v2<BDataType,
ComputeDataType,
BGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
BScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{b_grid_desc_m, thread_store_global_offset};
auto c_global_load =
ThreadwiseTensorSliceTransfer_v2<CDataType,
ComputeDataType,
CGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
CScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{c_grid_desc_m, thread_store_global_offset};
auto d_global_load =
ThreadwiseTensorSliceTransfer_v2<DDataType,
ComputeDataType,
DGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
DScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{d_grid_desc_m, thread_store_global_offset};
auto e_global_load =
ThreadwiseTensorSliceTransfer_v2<EDataType,
ComputeDataType,
EGridDesc_M,
decltype(thread_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
EScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{e_grid_desc_m, thread_store_global_offset};
auto f_global_write =
ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
FDataType,
decltype(thread_desc_m),
FGridDesc_M,
PassThrough,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // DstVectorDim
FScalarPerVector, // ScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
false>{
f_grid_desc_m, thread_store_global_offset, PassThrough{}};
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = c_grid_desc_m.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
index_t num_iter = M / (loop_step);
do
{
// read and process MPerThread elements
a_global_load.Run(
a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);
b_global_load.Run(
b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);
c_global_load.Run(
c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf);
d_global_load.Run(
d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf);
e_global_load.Run(
e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf);
static_for<0, MPerThread, 1>{}([&](auto m) {
constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
functor(f_thread_buf(Number<offset>{}),
a_thread_buf(Number<offset>{}),
b_thread_buf(Number<offset>{}),
c_thread_buf(Number<offset>{}),
d_thread_buf(Number<offset>{}),
e_thread_buf(Number<offset>{}));
});
f_global_write.Run(thread_desc_m,
make_tuple(I0), // SrcSliceOriginIdx
f_thread_buf,
f_grid_desc_m,
f_global_buf);
a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index);
d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index);
e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index);
f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index);
} while(--num_iter);
}
};
} // namespace ck
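Gridwise5AryElementwise_1D streams MPerThread elements per thread per iteration, applies the user functor to the five inputs, writes the result, and then advances every slice window by blockPerGrid * blockSize * MPerThread, i.e. a grid-stride loop. Below is a host-side sketch of that access pattern; the example functor, sizes, and data are assumptions for illustration and not part of the CK interface.

#include <cstdio>
#include <vector>

int main()
{
    const int M          = 32;
    const int MPerThread = 2;
    const int threads    = 4; // blockPerGrid * blockSize analogue
    const int loop_step  = threads * MPerThread;

    std::vector<float> a(M, 1.0f), b(M, 2.0f), c(M, 3.0f), d(M, 4.0f), e(M, 5.0f), f(M, 0.0f);

    // functor analogue: F = A + B * C - D / E
    auto functor = [](float& y, float a_, float b_, float c_, float d_, float e_) {
        y = a_ + b_ * c_ - d_ / e_;
    };

    for(int t = 0; t < threads; ++t) // each "thread" owns MPerThread consecutive elements
    {
        for(int base = t * MPerThread; base < M; base += loop_step) // grid-stride loop
        {
            for(int m = 0; m < MPerThread; ++m) // MPerThread elements per pass
                functor(f[base + m],
                        a[base + m],
                        b[base + m],
                        c[base + m],
                        d[base + m],
                        e[base + m]);
        }
    }

    std::printf("f[0] = %f, f[M-1] = %f\n", f[0], f[M - 1]); // both 6.2
    return 0;
}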