Commit d8f1458f authored by Jing Zhang

Merge remote-tracking branch 'origin/develop' into grouped_gemm_args_const_buff

parents 6e983ba2 40b59a63
-#ifndef DEVICE_GEMM_XDL_HPP
-#define DEVICE_GEMM_XDL_HPP
+#pragma once
 #include <iostream>
 #include <sstream>
@@ -12,6 +11,7 @@
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_xdlops_v2r3.hpp"
 #include "gemm_specialization.hpp"
+#include "device_prop.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -408,6 +408,11 @@ struct DeviceGemmXdl
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        {
+            return false;
+        }
+
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                            arg.b_grid_desc_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
@@ -515,4 +520,3 @@ struct DeviceGemmXdl
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
-#endif
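Note: the newly included device_prop.hpp supplies ck::get_device_name(), which the guard above uses to restrict these XDL kernels to gfx908/gfx90a. As a rough illustration only (not the actual CK implementation), such a helper can be built on the HIP runtime like this:

#include <string>
#include <hip/hip_runtime.h>

// Hypothetical sketch of a device-name helper; name and error handling are
// illustrative, not taken from device_prop.hpp.
inline std::string get_device_name_sketch()
{
    int device = 0;
    hipDeviceProp_t props{};
    if(hipGetDevice(&device) != hipSuccess || hipGetDeviceProperties(&props, device) != hipSuccess)
        return {};

    const std::string arch = props.gcnArchName; // e.g. "gfx90a:sramecc+:xnack-"
    return arch.substr(0, arch.find(':'));      // keep only the "gfx90a" part
}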
@@ -9,6 +9,7 @@
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_xdl_cshuffle_v1.hpp"
 #include "tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "device_prop.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -558,6 +559,11 @@ struct DeviceGemm_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        {
+            return false;
+        }
+
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.c_grid_desc_m_n_,
...
@@ -12,6 +12,7 @@
 #include "tensor_descriptor_helper.hpp"
 #include "gridwise_gemm_xdlops_v2r4.hpp"
 #include "gemm_specialization.hpp"
+#include "device_prop.hpp"

 #ifndef CK_RUN_KERNEL_AND_TIME
 #define CK_RUN_KERNEL_AND_TIME 1
@@ -528,6 +529,11 @@ struct DeviceGemmXdlSplitK
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        {
+            return false;
+        }
+
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
                                            arg.b_grid_desc_kbatch_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
...
@@ -17,7 +17,7 @@ template <typename InDataType,
           typename OutDataType,
           typename AccDataType,
           ck::ReduceTensorOp ReduceOpId,
-          bool NeedIndices,
+          bool OuputIndex,
           ck::index_t BlockSize,
           ck::index_t ReduceMThreadClusterSize,
           ck::index_t ReduceKThreadClusterSize,
@@ -44,8 +44,6 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
         typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
             AccElementwiseOperation;

-    static constexpr bool BetaIsZero = true;
-
     static constexpr index_t InSrcOutDstVectorDim =
         0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
            // not reduced.
@@ -206,28 +204,28 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
-            using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
-                                                                         OutDataType,
-                                                                         AccDataType,
-                                                                         IndexDataType,
-                                                                         AGridDesc_M_K,
-                                                                         BGridDesc_M,
-                                                                         ReduceOperation,
-                                                                         InElementwiseOperation,
-                                                                         AccElementwiseOperation,
-                                                                         false, // propagate_nan
-                                                                         BetaIsZero,
-                                                                         BlockSize,
-                                                                         ReduceMThreadClusterSize,
-                                                                         ReduceKThreadClusterSize,
-                                                                         ReduceMThreadSliceSize,
-                                                                         ReduceKThreadSliceSize,
-                                                                         InSrcOutDstVectorDim,
-                                                                         InSrcOutDstVectorSize,
-                                                                         InSrcOutDstVectorSize>;
+            using gridwise_reduce =
+                GridwiseReduction_mk_to_m_threadwise<InDataType,
+                                                     OutDataType,
+                                                     AccDataType,
+                                                     IndexDataType,
+                                                     AGridDesc_M_K,
+                                                     BGridDesc_M,
+                                                     ReduceOperation,
+                                                     InElementwiseOperation,
+                                                     AccElementwiseOperation,
+                                                     InMemoryDataOperationEnum::Set,
+                                                     false, // propagate_nan
+                                                     BlockSize,
+                                                     ReduceMThreadSliceSize,
+                                                     ReduceKThreadSliceSize,
+                                                     InSrcOutDstVectorDim,
+                                                     InSrcOutDstVectorSize,
+                                                     InSrcOutDstVectorSize>;

             const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
-                                                         NeedIndices,
+                                                         OuputIndex,
+                                                         false, // don't have index input
                                                          InDataType,
                                                          OutDataType,
                                                          AccDataType,
@@ -252,6 +250,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd
                                           arg.acc_element_op_,
                                           float(1),
                                           arg.p_in_dev_,
+                                          nullptr,
                                           float(0),
                                           arg.p_out_dev_,
                                           arg.p_out_indices_dev_);
...
@@ -16,35 +16,18 @@ namespace device {
 template <typename InElementwiseOperation, typename AccElementwiseOperation>
 struct DeviceReduce : public BaseOperator
 {
-    virtual long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
-                                                 const std::vector<int> reduceDims)
-    {
-        (void)inLengths;
-        (void)reduceDims;
-        return (0);
-    };
-
-    virtual bool HasFurtherCall() { return (false); };
-
-    virtual std::vector<int> GetWorkspace2dLengths(const BaseArgument* argPtr)
-    {
-        (void)argPtr;
-        return (std::vector<int>{0, 0});
-    };
-
     virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int> inLengths,
-                        const std::vector<int> inStrides,
-                        const std::vector<int> outLengths,
-                        const std::vector<int> outStrides,
+    MakeArgumentPointer(const std::vector<index_t> inLengths,
+                        const std::vector<index_t> inStrides,
+                        const std::vector<index_t> outLengths,
+                        const std::vector<index_t> outStrides,
                         const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
+                        const void* in_index_dev,
                         void* out_dev,
-                        void* out_indices_dev,
-                        void* workspace_dev,
+                        void* out_index_dev,
                         const InElementwiseOperation in_elementwise_op,
                         const AccElementwiseOperation acc_elementwise_op) = 0;
...
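Note: with this interface change, lengths and strides are passed as index_t, an optional input-index buffer is accepted, and the workspace pointer is gone. A hypothetical call site under the new signature (the reduce_op instance, the device pointers and the element-wise operators are assumed to exist, not shown in this commit):

auto argument_ptr = reduce_op.MakeArgumentPointer(inLengths,   // std::vector<index_t>
                                                  inStrides,   // std::vector<index_t>
                                                  outLengths,  // std::vector<index_t>
                                                  outStrides,  // std::vector<index_t>
                                                  reduceDims,  // std::vector<int>
                                                  1.0f,        // alpha
                                                  0.0f,        // beta
                                                  p_in_dev,
                                                  nullptr,     // in_index_dev: no index input
                                                  p_out_dev,
                                                  p_out_index_dev, // may be nullptr if no index output
                                                  in_elementwise_op,
                                                  acc_elementwise_op);

auto invoker_ptr = reduce_op.MakeInvokerPointer();

if(reduce_op.IsSupportedArgument(argument_ptr.get()))
    invoker_ptr->Run(argument_ptr.get(), StreamConfig{});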
#ifndef DEVICE_REDUCE_BLOCKWISE_HPP
#define DEVICE_REDUCE_BLOCKWISE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceBlockWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto inPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, inPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m =
DeviceReduceBlockWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_reduce_blockwise<GridwiseReduce,
NeedIndices,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_,
nullptr,
arg.out_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// cases with very small reduce_total_length should be handled by the ThreadWise method
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceBlockWise<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
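Note: MakeSrc2dDescriptor above folds the (already shuffled) invariant dimensions into M and the reduce dimensions into K, then right-pads both to block-tile multiples; the grid covers the padded M dimension. A minimal host-side sketch of that sizing arithmetic (illustrative names, not CK code):

#include <cstdint>
#include <vector>

// ceil(x / m) * m, i.e. what math::integer_least_multiple computes
inline int64_t least_multiple_sketch(int64_t x, int64_t m) { return ((x + m - 1) / m) * m; }

// inLengths is already shuffled: the first numInvariantDim entries are the kept dims.
inline size_t blockwise_grid_size_sketch(const std::vector<int>& inLengths,
                                         int numInvariantDim,
                                         int mBlockTileSize) // MThreadClusterSize * MThreadSliceSize
{
    int64_t invariant_total_length = 1;
    for(int d = 0; d < numInvariantDim; ++d)
        invariant_total_length *= inLengths[d];

    // one workgroup per M block tile of the padded invariant length
    return static_cast<size_t>(least_multiple_sketch(invariant_total_length, mBlockTileSize) /
                               mBlockTileSize);
}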
#ifndef DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
#define DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceBlockWiseSecondCall
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices;
static_assert(
std::is_same<InDataType, AccDataType>::value,
"InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<2>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<2>{});
const auto in_grid_desc_m_k =
make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
const std::vector<int>& outStrides)
{
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
const std::vector<int>& outLengths,
const std::vector<int>& outStrides,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op)
: inLengths_(inLengths),
inStrides_(inStrides),
outLengths_(outLengths),
outStrides_(outStrides),
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
in_elementwise_op_(in_elementwise_op),
acc_elementwise_op_(acc_elementwise_op)
{
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
invariant_total_length = inLengths[0];
reduce_total_length = inLengths[1];
invariant_lowest_length = inLengths[0];
reduce_lowest_length = inLengths[1];
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
invariant_total_length * reduce_total_length * sizeof(AccDataType), 64);
if constexpr(NeedIndices)
workspace_indices_dev_ = reinterpret_cast<index_t*>(
reinterpret_cast<char*>(workspace_dev) + ws_buf2_bytes_offset);
else
workspace_indices_dev_ = nullptr;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
IndexDataType* workspace_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m = DeviceReduceBlockWiseSecondCall::MakeDst1dDescriptor(
arg.outLengths_, arg.outStrides_);
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_reduce_blockwise_second_call<GridwiseReduce,
NeedIndices,
InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.alpha_,
arg.in_dev_,
arg.beta_,
arg.out_dev_,
arg.workspace_indices_dev_,
arg.out_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
// To improve
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
// cases with very small reduce_total_length should be handled by the ThreadWise method
if(pArg->reduce_total_length / KThreadSliceSize < 2)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
(void)reduceDims;
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceBlockWiseSecondCall<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
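Note: the second-call operator consumes the workspace written by a first partial-reduction pass: a buffer of AccDataType partial values followed, at a 64-byte aligned offset, by an int32_t index buffer when NeedIndices is set (see ws_buf2_bytes_offset above). A small sketch of that offset computation (illustrative only):

#include <cstddef>

// Byte offset of the index section inside the workspace, mirroring
// math::integer_least_multiple(value_bytes, 64) in the Argument constructor.
inline size_t second_call_index_offset_sketch(size_t invariant_total_length,
                                              size_t reduce_total_length,
                                              size_t acc_type_size) // sizeof(AccDataType)
{
    const size_t value_bytes = invariant_total_length * reduce_total_length * acc_type_size;
    return ((value_bytes + 63) / 64) * 64;
}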
@@ -14,13 +14,13 @@ namespace device {
 // here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
 // of reduce dims
-template <int Rank, int NumReduceDim>
-std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
+template <index_t Rank, int NumReduceDim>
+std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
 {
     static_assert(Rank <= 6, "bigger Rank size not supported!");

-    size_t invariant_total_length = 1;
-    size_t reduce_total_length    = 1;
+    long_index_t invariant_total_length = 1;
+    long_index_t reduce_total_length    = 1;

     constexpr int NumInvariantDim = Rank - NumReduceDim;
@@ -35,13 +35,13 @@ std::pair<size_t, size_t> get_2d_lengths(const std::vector<int>& inLengths)
 // helper functions using variadic template arguments
 template <index_t... Ns>
-auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths, Sequence<Ns...>)
+auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
 {
     return make_tuple(static_cast<index_t>(lengths[Ns])...);
 };

 template <index_t arraySize>
-static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arraySize>)
+auto make_tuple_from_array(const std::vector<index_t>& lengths, Number<arraySize>)
 {
     static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
@@ -51,10 +51,10 @@ static auto make_tuple_from_array(const std::vector<int>& lengths, Number<arrayS
 };

 template <index_t Rank, index_t NumReduceDim>
-std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
-                                           const std::vector<int>& reduceDims)
+std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origLengthsStrides,
+                                               const std::vector<int>& reduceDims)
 {
-    std::vector<int> newLengthsStrides;
+    std::vector<index_t> newLengthsStrides;

     assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
...
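Note: these helpers reorder the lengths/strides so that the kept (invariant) dimensions come first and the reduced ones last, then collapse each group into a single length. A small illustrative example (not part of the commit):

// Rank = 4, lengths = {N, C, H, W}
//   reduceDims = {2, 3}: shuffle_tensor_dimensions -> {N, C, H, W}
//                        get_2d_lengths            -> {N * C, H * W}  (invariant, reduce)
//   reduceDims = {0, 2}: shuffle_tensor_dimensions -> {C, W, N, H}
//                        get_2d_lengths            -> {C * W, N * H}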
-#ifndef DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
-#define DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP
+#ifndef DEVICE_REDUCE_MULTIBLOCK_HPP
+#define DEVICE_REDUCE_MULTIBLOCK_HPP

 #include <iostream>
 #include <sstream>
@@ -7,8 +7,9 @@
 #include "device_base.hpp"
 #include "device_reduce.hpp"
 #include "device_reduce_common.hpp"
-#include "gridwise_2d_reduction_multiblock_atomic_add.hpp"
+#include "gridwise_2d_reduction_multiblock.hpp"
 #include "gridwise_set_buffer_value.hpp"
+#include "reduction_operator.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -22,8 +23,10 @@ template <typename InDataType,
           typename ReduceOperation,
           typename InElementwiseOperation,
           typename AccElementwiseOperation,
+          InMemoryDataOperationEnum OutMemoryDataOperation,
           bool PropagateNan,
-          bool NeedIndices,
+          bool OutputIndex,
+          bool HaveIndexInputIfOutputIndex,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,
@@ -32,8 +35,7 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceMultiBlockAtomicAdd
-    : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
@@ -46,26 +48,40 @@ struct DeviceReduceMultiBlockAtomicAdd
     using IndexDataType = int32_t;

+    static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
+
     static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

     static constexpr index_t numSrcDim = Rank;
     static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
     static constexpr bool reduceAllDim = (NumInvariantDim == 0);

-    static constexpr bool support_AtomicAdd =
+    // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
+    // later
+    static constexpr bool use_multiblock =
+        (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
+
+    static constexpr bool out_type_compatible_with_atomic_op =
         std::is_same<OutDataType, float>::value || std::is_same<OutDataType, double>::value;

-    static_assert(!NeedIndices && support_AtomicAdd,
-                  "MultiBlockAtomicAdd method can only be used with non-indiced operation and when "
-                  "having float/double output type!");
+    static_assert(
+        !use_multiblock || (use_multiblock && out_type_compatible_with_atomic_op),
+        "The OutDataType must support the atomic operation for using MultiBlock reduction");
+
+    static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
+                  "MultiBlock reduction can only be used when outputing index is not required");
+
+    static_assert(
+        ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
+        "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");

-    static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
-    static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

-    static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
-                                    const std::vector<int>& inStrides,
+    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
+                                    const std::vector<index_t>& inStrides,
                                     int blkGroupSize,
-                                    int kBlockTileIterations)
+                                    int numBlockTileIteration)
     {
         const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
         const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
@@ -109,7 +125,7 @@ struct DeviceReduceMultiBlockAtomicAdd
         const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
         const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

-        const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
+        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
         const auto inPad_M =
             math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
         const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
@@ -124,8 +140,8 @@ struct DeviceReduceMultiBlockAtomicAdd
         return (in_grid_desc_m_k_padded);
     };

-    static auto MakeDst1dDescriptor(const std::vector<int>& outLengths,
-                                    const std::vector<int>& outStrides)
+    static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
+                                    const std::vector<index_t>& outStrides)
     {
         const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
         const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
@@ -151,31 +167,56 @@ struct DeviceReduceMultiBlockAtomicAdd
         return (out_grid_desc_m_padded);
     };

+    static auto MakeDst1dDescriptorForBufferSet(const std::vector<index_t>& outLengths,
+                                                const std::vector<index_t>& outStrides)
+    {
+        const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
+        const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
+
+        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
+
+        auto out_grid_desc_m = transform_tensor_descriptor(
+            outDesc,
+            make_tuple(make_merge_transform(tupleDstLengths)),
+            make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}),
+            make_tuple(Sequence<0>{}));
+
+        const auto length = out_grid_desc_m.GetLength(Number<0>{});
+
+        const auto pad = math::integer_least_multiple(length, BlockSize) - length;
+
+        auto out_grid_desc_m_padded =
+            transform_tensor_descriptor(out_grid_desc_m,
+                                        make_tuple(make_right_pad_transform(length, pad)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return (out_grid_desc_m_padded);
+    };
+
     struct Argument : public BaseArgument
     {
-        Argument(const std::vector<int> inLengths,
-                 const std::vector<int> inStrides,
-                 const std::vector<int> outLengths,
-                 const std::vector<int> outStrides,
+        Argument(const std::vector<index_t> inLengths,
+                 const std::vector<index_t> inStrides,
+                 const std::vector<index_t> outLengths,
+                 const std::vector<index_t> outStrides,
                  const std::vector<int> reduceDims,
                  float alpha,
                  float beta,
                  const InDataType* in_dev,
+                 const IndexDataType* in_index_dev,
                  OutDataType* out_dev,
-                 IndexDataType* out_indices_dev,
-                 AccDataType* workspace_dev,
+                 IndexDataType* out_index_dev,
                  const InElementwiseOperation in_elementwise_op,
                  const AccElementwiseOperation acc_elementwise_op)
             : outLengths_{outLengths},
               outStrides_{outStrides},
               in_dev_{in_dev},
+              in_index_dev_{in_index_dev},
               out_dev_{out_dev},
+              out_index_dev_{out_index_dev},
               in_elementwise_op_{in_elementwise_op},
               acc_elementwise_op_{acc_elementwise_op}
         {
-            (void)out_indices_dev;
-            (void)workspace_dev;
-
             inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
             inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
@@ -192,23 +233,34 @@ struct DeviceReduceMultiBlockAtomicAdd
             reduce_lowest_length = inLengths_[Rank - 1];

-            int iterations = 1;
-            while(true)
-            {
-                int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
-                                       (K_BlockTileSize * iterations);
-
-                // we want the blkGroupSize be not more than 128
-                if(testBlkGroupSize <= 128)
-                    break;
-
-                iterations++;
-            };
-
-            blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
-                           (K_BlockTileSize * iterations);
-
-            kBlockTileIterations = iterations;
+            if constexpr(use_multiblock)
+            {
+                int iterations = 1;
+                while(true)
+                {
+                    int testBlkGroupSize =
+                        (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
+                        (K_BlockTileSize * iterations);
+
+                    // we want the blkGroupSize be not more than 128
+                    if(testBlkGroupSize <= 128)
+                        break;
+
+                    iterations++;
+                };
+
+                blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
+                               (K_BlockTileSize * iterations);
+
+                numBlockTileIteration = iterations;
+            }
+            else
+            {
+                blkGroupSize = 1;
+                numBlockTileIteration =
+                    (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
+            };

             gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                        M_BlockTileSize * blkGroupSize;
@@ -217,27 +269,29 @@ struct DeviceReduceMultiBlockAtomicAdd
                 math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
         }

-        std::vector<int> inLengths_;
-        std::vector<int> inStrides_;
-        std::vector<int> outLengths_;
-        std::vector<int> outStrides_;
+        std::vector<index_t> inLengths_;
+        std::vector<index_t> inStrides_;
+        std::vector<index_t> outLengths_;
+        std::vector<index_t> outStrides_;

         AccDataType alpha_;
         AccDataType beta_;

         const InDataType* in_dev_;
+        const IndexDataType* in_index_dev_;
         OutDataType* out_dev_;
+        IndexDataType* out_index_dev_;

         InElementwiseOperation in_elementwise_op_;
         AccElementwiseOperation acc_elementwise_op_;

-        int invariant_lowest_length;
-        int reduce_lowest_length;
-        size_t invariant_total_length;
-        size_t reduce_total_length;
+        index_t invariant_lowest_length;
+        index_t reduce_lowest_length;
+        long_index_t invariant_total_length;
+        long_index_t reduce_total_length;

-        index_t blkGroupSize;
-        index_t kBlockTileIterations;
+        int blkGroupSize;
+        int numBlockTileIteration;
         size_t gridSize;

         size_t gridSize_pre;
@@ -247,52 +301,69 @@ struct DeviceReduceMultiBlockAtomicAdd
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor(
-                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
-            const auto out_grid_desc_m = DeviceReduceMultiBlockAtomicAdd::MakeDst1dDescriptor(
-                arg.outLengths_, arg.outStrides_);
-
-            using InGridDesc_M_K = decltype(in_grid_desc_m_k);
-            using OutGridDesc_M  = decltype(out_grid_desc_m);
-
-            using GridwiseReduce =
-                GridwiseReduction_mk_to_m_multiblock_atomic_add<InDataType,
-                                                                OutDataType,
-                                                                AccDataType,
-                                                                InGridDesc_M_K,
-                                                                OutGridDesc_M,
-                                                                ReduceOperation,
-                                                                InElementwiseOperation,
-                                                                AccElementwiseOperation,
-                                                                PropagateNan,
-                                                                BlockSize,
-                                                                MThreadClusterSize,
-                                                                KThreadClusterSize,
-                                                                MThreadSliceSize,
-                                                                KThreadSliceSize,
-                                                                InSrcVectorDim,
-                                                                InSrcVectorSize,
-                                                                OutDstVectorSize>;
-
-            float avg_time = 0;
-
-            const auto kernel_pre = kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M>;
-            const auto kernel_main = kernel_reduce_multiblock_atocmi_add<GridwiseReduce,
-                                                                         InDataType,
-                                                                         OutDataType,
-                                                                         AccDataType,
-                                                                         InGridDesc_M_K,
-                                                                         OutGridDesc_M,
-                                                                         InElementwiseOperation,
-                                                                         AccElementwiseOperation>;
-
-            avg_time += launch_and_time_kernel(stream_config,
-                                               kernel_pre,
-                                               dim3(arg.gridSize_pre),
-                                               dim3(BlockSize),
-                                               0,
-                                               out_grid_desc_m,
-                                               arg.out_dev_,
-                                               static_cast<OutDataType>(0.0f));
-
+            const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor(
+                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
+            const auto out_grid_desc_m =
+                DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
+            const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet(
+                arg.outLengths_, arg.outStrides_);
+
+            using InGridDesc_M_K  = decltype(in_grid_desc_m_k);
+            using OutGridDesc_M   = decltype(out_grid_desc_m);
+            using OutGridDesc_M_2 = decltype(out_grid_desc_m_2);
+
+            using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock<InDataType,
+                                                                        OutDataType,
+                                                                        AccDataType,
+                                                                        IndexDataType,
+                                                                        InGridDesc_M_K,
+                                                                        OutGridDesc_M,
+                                                                        ReduceOperation,
+                                                                        InElementwiseOperation,
+                                                                        AccElementwiseOperation,
+                                                                        OutMemoryDataOperation,
+                                                                        PropagateNan,
+                                                                        BlockSize,
+                                                                        MThreadClusterSize,
+                                                                        KThreadClusterSize,
+                                                                        MThreadSliceSize,
+                                                                        KThreadSliceSize,
+                                                                        InSrcVectorDim,
+                                                                        InSrcVectorSize,
+                                                                        OutDstVectorSize>;
+
+            const auto kernel_main = kernel_reduce_multiblock<GridwiseReduce,
+                                                              OutputIndex,
+                                                              HaveIndexInput,
+                                                              InDataType,
+                                                              OutDataType,
+                                                              AccDataType,
+                                                              int32_t,
+                                                              InGridDesc_M_K,
+                                                              OutGridDesc_M,
+                                                              InElementwiseOperation,
+                                                              AccElementwiseOperation>;
+
+            float avg_time = 0;
+
+            if constexpr(use_multiblock)
+            {
+                const auto zeroVal =
+                    ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
+                        OutMemoryDataOperation);
+
+                const auto kernel_pre =
+                    kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M_2>;
+
+                avg_time += launch_and_time_kernel(stream_config,
+                                                   kernel_pre,
+                                                   dim3(arg.gridSize_pre),
+                                                   dim3(BlockSize),
+                                                   0,
+                                                   out_grid_desc_m_2,
+                                                   arg.out_dev_,
+                                                   zeroVal);
+            };
+
             avg_time += launch_and_time_kernel(stream_config,
                                                kernel_main,
@@ -304,25 +375,34 @@ struct DeviceReduceMultiBlockAtomicAdd
                                               arg.in_elementwise_op_,
                                               arg.acc_elementwise_op_,
                                               arg.blkGroupSize,
-                                              arg.kBlockTileIterations,
+                                              arg.numBlockTileIteration,
                                               arg.alpha_,
                                               arg.in_dev_,
-                                              arg.out_dev_);
+                                              arg.in_index_dev_,
+                                              arg.beta_,
+                                              arg.out_dev_,
+                                              arg.out_index_dev_);

-            return avg_time;
-        }
+            return (avg_time);
+        };

         float Run(const BaseArgument* p_arg,
                   const StreamConfig& stream_config = StreamConfig{}) override
         {
             return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
-        }
+        };
     };
     bool IsSupportedArgument(const BaseArgument* p_arg) override
     {
         const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

+        if constexpr(use_multiblock)
+        {
+            if(static_cast<float>(pArg->beta_) != 0.0f)
+                return (false);
+        };
+
         if constexpr(InSrcVectorDim == 0)
         {
             if constexpr(NumInvariantDim == 0)
@@ -347,36 +427,43 @@ struct DeviceReduceMultiBlockAtomicAdd
                 return (false);
         };

-        if(static_cast<float>(pArg->beta_) != 0.0f)
-            return (false);
-
         // To improve
         if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
             return (false);

-        // cases with small reduce_total_length should be handled by the BlockWise method
-        if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
-            return (false);
-
-        // This is very strong restriction, but needed to avoid some failure
-        if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
-            return (false);
+        if constexpr(use_multiblock)
+        {
+            // blkGroupSize of 1 should be handled by Blockwise path using
+            // InMemoryDataOperationEnum::Set
+            if(pArg->blkGroupSize == 1)
+                return (false);
+
+            // This is very strong restriction, but needed to avoid some failure
+            if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
+                return (false);
+        }
+        else
+        {
+            // cases with very small reduce_total_length should be handled by ThreadWise kernel
+            if(pArg->reduce_total_length / KThreadSliceSize < 2)
+                return (false);
+        };

         return (true);
     };
     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const std::vector<int> inLengths,
-                        const std::vector<int> inStrides,
-                        const std::vector<int> outLengths,
-                        const std::vector<int> outStrides,
+    MakeArgumentPointer(const std::vector<index_t> inLengths,
+                        const std::vector<index_t> inStrides,
+                        const std::vector<index_t> outLengths,
+                        const std::vector<index_t> outStrides,
                         const std::vector<int> reduceDims,
                         float alpha,
                         float beta,
                         const void* in_dev,
+                        const void* in_index_dev,
                         void* out_dev,
-                        void* out_indices_dev,
-                        void* workspace_dev,
+                        void* out_index_dev,
                         const InElementwiseOperation in_elementwise_op,
                         const AccElementwiseOperation acc_elementwise_op) override
     {
@@ -388,9 +475,9 @@ struct DeviceReduceMultiBlockAtomicAdd
                                           alpha,
                                           beta,
                                           static_cast<const InDataType*>(in_dev),
+                                          static_cast<const IndexDataType*>(in_index_dev),
                                           static_cast<OutDataType*>(out_dev),
-                                          static_cast<IndexDataType*>(out_indices_dev),
-                                          static_cast<AccDataType*>(workspace_dev),
+                                          static_cast<IndexDataType*>(out_index_dev),
                                           in_elementwise_op,
                                           acc_elementwise_op);
     };
...
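Note: for the AtomicAdd (use_multiblock) path, the Argument constructor searches for the smallest per-block tile-iteration count that keeps blkGroupSize at or below 128; the non-multiblock path simply uses a single block group. A host-side sketch with a worked example (illustrative only, not CK code):

#include <cstdint>

inline void choose_blk_group_size_sketch(int64_t reduce_total_length,
                                         int64_t k_block_tile_size, // KThreadClusterSize * KThreadSliceSize
                                         int& blkGroupSize,
                                         int& numBlockTileIteration)
{
    int iterations = 1;
    // smallest iteration count whose resulting block-group count is <= 128
    while((reduce_total_length + k_block_tile_size * iterations - 1) /
              (k_block_tile_size * iterations) >
          128)
        ++iterations;

    blkGroupSize = static_cast<int>((reduce_total_length + k_block_tile_size * iterations - 1) /
                                    (k_block_tile_size * iterations));
    numBlockTileIteration = iterations;
}

// Example: reduce_total_length = 1000000, k_block_tile_size = 256
//   iterations = 30 gives ceil(1000000 / 7680) = 131 > 128, so keep searching;
//   iterations = 31 gives blkGroupSize = ceil(1000000 / 7936) = 127 <= 128,
//   i.e. each output row is reduced by 127 workgroups doing 31 K-tiles each.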
#ifndef DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#define DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_reduce.hpp"
#include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
bool NeedIndices,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceMultiBlockPartialReduce
: public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
using IndexDataType = int32_t;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr int MaxBlockGroupSize = 256;
long_index_t GetWorkspaceSizeInBytes(const std::vector<int> inLengths,
const std::vector<int> reduceDims) override
{
size_t invariant_total_length;
size_t reduce_total_length;
auto inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
};
int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
long_index_t workspace_size = invariant_total_length * blkGroupSize;
long_index_t wsSizeInBytes =
!NeedIndices
? workspace_size * sizeof(AccDataType)
: workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int);
return (wsSizeInBytes);
};
bool HasFurtherCall() override { return (true); };
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths,
const std::vector<int>& inStrides,
int blkGroupSize,
int kBlockTileIterations)
{
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize)
{
auto ws_desc_m_k =
make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));
const auto wsPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto ws_desc_m_k_padded =
transform_tensor_descriptor(ws_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, wsPad),
make_pass_through_transform(blkGroupSize)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (ws_desc_m_k_padded);
};
struct Argument : public BaseArgument
{
Argument(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_indices_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
out_indices_dev_{out_indices_dev},
workspace_dev_{workspace_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
alpha_ = type_convert<AccDataType>(alpha);
beta_ = type_convert<AccDataType>(beta);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
int iterations = 1;
while(true)
{
int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
if(testBlkGroupSize <= MaxBlockGroupSize)
break;
iterations++;
};
blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
(K_BlockTileSize * iterations);
kBlockTileIterations = iterations;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize;
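// The (optional) index workspace follows the partial-value workspace; its byte offset is
// rounded up to a multiple of 64 for alignment.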
size_t ws_buf2_bytes_offset = math::integer_least_multiple(
invariant_total_length * blkGroupSize * sizeof(AccDataType), 64);
if constexpr(NeedIndices)
workspace_indices_dev_ = reinterpret_cast<int*>(
reinterpret_cast<char*>(workspace_dev_) + ws_buf2_bytes_offset);
else
workspace_indices_dev_ = nullptr;
}
std::vector<int> inLengths_;
std::vector<int> inStrides_;
std::vector<int> outLengths_;
std::vector<int> outStrides_;
AccDataType alpha_;
AccDataType beta_;
const InDataType* in_dev_;
OutDataType* out_dev_;
IndexDataType* out_indices_dev_;
AccDataType* workspace_dev_;
IndexDataType* workspace_indices_dev_;
InElementwiseOperation in_elementwise_op_;
AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length;
int reduce_lowest_length;
size_t invariant_total_length;
size_t reduce_total_length;
index_t blkGroupSize;
index_t kBlockTileIterations;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor(
arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations);
const auto ws_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeWorkspace2dDescriptor(
arg.invariant_total_length, arg.blkGroupSize);
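// The workspace is viewed as an (invariant_total_length x blkGroupSize) 2-D tensor:
// each block group contributes one column of partial results.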
using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using WorkspaceDesc_M_K = decltype(ws_desc_m_k);
using GridwiseReduce =
GridwiseReduction_mk_to_mk_multiblock_partial_reduce<InDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
WorkspaceDesc_M_K,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0;
const auto kernel = kernel_partial_reduce_multiblock<GridwiseReduce,
NeedIndices,
InDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
WorkspaceDesc_M_K,
InElementwiseOperation,
AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
ws_desc_m_k,
arg.in_elementwise_op_,
arg.acc_elementwise_op_,
arg.blkGroupSize,
arg.kBlockTileIterations,
arg.in_dev_,
arg.workspace_dev_,
arg.workspace_indices_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
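// Only OutDstVectorSize == 1 is accepted here: the partial results are written to the
// workspace one element per block group, so output vectorization is not used.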
if constexpr(OutDstVectorSize != 1)
return (false);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// cases with small reduce_total_length should be handled by the BlockWise method
if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize)
return (false);
return (true);
};
std::vector<int> GetWorkspace2dLengths(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
return (
std::vector<int>{static_cast<int>(pArg->invariant_total_length), pArg->blkGroupSize});
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths,
const std::vector<int> inStrides,
const std::vector<int> outLengths,
const std::vector<int> outStrides,
const std::vector<int> reduceDims,
float alpha,
float beta,
const void* in_dev,
void* out_dev,
void* out_indices_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
outLengths,
outStrides,
reduceDims,
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op,
acc_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceMultiBlockPartialReduce<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "device.hpp" #include "device.hpp"
#include "device_reduce.hpp" #include "device_reduce.hpp"
#include "device_reduce_common.hpp" #include "device_reduce_common.hpp"
#include "gridwise_2d_reduction_multiblock.hpp"
#include "gridwise_2d_reduction_threadwise.hpp" #include "gridwise_2d_reduction_threadwise.hpp"
namespace ck { namespace ck {
...@@ -19,22 +20,19 @@ template <typename InDataType, ...@@ -19,22 +20,19 @@ template <typename InDataType,
index_t NumReduceDim, index_t NumReduceDim,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename OutElementwiseOperation, typename AccElementwiseOperation,
bool PropagateNan, bool PropagateNan,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInputIfOutputIndex,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutElementwiseOperation> struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, AccElementwiseOperation>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1),
"Threadwise can only be called with KThreadClusterSize be 1 !");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
...@@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
using IndexDataType = int32_t; using IndexDataType = int32_t;
static constexpr bool BetaIsZero = NeedIndices; static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
...@@ -51,11 +49,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -51,11 +49,11 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0); static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
static auto MakeSrc2dDescriptor(const std::vector<int>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<int>& inStrides) const std::vector<index_t>& inStrides)
{ {
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{}); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
...@@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
static auto MakeDst1dDescriptor(const std::vector<int>& outLengths, static auto MakeDst1dDescriptor(const std::vector<index_t>& outLengths,
const std::vector<int>& outStrides) const std::vector<index_t>& outStrides)
{ {
const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{}); const auto tupleDstLengths = make_tuple_from_array(outLengths, Number<numDstDim>{});
const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number<numDstDim>{});
...@@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const std::vector<int> inLengths, Argument(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const InDataType* in_dev, const InDataType* in_dev,
OutDataType* out_dev, OutDataType* out_dev,
IndexDataType* out_indices_dev, IndexDataType* out_index_dev,
AccDataType* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
: outLengths_{outLengths}, : outLengths_{outLengths},
outStrides_{outStrides}, outStrides_{outStrides},
in_dev_{in_dev}, in_dev_{in_dev},
out_dev_{out_dev}, out_dev_{out_dev},
out_indices_dev_{out_indices_dev}, out_index_dev_{out_index_dev},
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
(void)workspace_dev;
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
...@@ -183,30 +177,33 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -183,30 +177,33 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
reduce_lowest_length = inLengths_[Rank - 1]; reduce_lowest_length = inLengths_[Rank - 1];
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize; M_BlockTileSize;
} }
std::vector<int> inLengths_; std::vector<index_t> inLengths_;
std::vector<int> inStrides_; std::vector<index_t> inStrides_;
std::vector<int> outLengths_; std::vector<index_t> outLengths_;
std::vector<int> outStrides_; std::vector<index_t> outStrides_;
AccDataType alpha_; AccDataType alpha_;
AccDataType beta_; AccDataType beta_;
const InDataType* in_dev_; const InDataType* in_dev_;
OutDataType* out_dev_; OutDataType* out_dev_;
IndexDataType* out_indices_dev_; IndexDataType* out_index_dev_;
InElementwiseOperation in_elementwise_op_; InElementwiseOperation in_elementwise_op_;
OutElementwiseOperation acc_elementwise_op_; AccElementwiseOperation acc_elementwise_op_;
int invariant_lowest_length; index_t invariant_lowest_length;
int reduce_lowest_length; index_t reduce_lowest_length;
size_t invariant_total_length; long_index_t invariant_total_length;
size_t reduce_total_length; long_index_t reduce_total_length;
int numBlockTileIteration;
size_t gridSize; size_t gridSize;
}; };
...@@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
using InGridDesc_M_K = decltype(in_grid_desc_m_k); using InGridDesc_M_K = decltype(in_grid_desc_m_k);
using OutGridDesc_M = decltype(out_grid_desc_m); using OutGridDesc_M = decltype(out_grid_desc_m);
using GridwiseReduce = GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation,
PropagateNan,
BetaIsZero,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
float avg_time = 0; float avg_time = 0;
using GridwiseReduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
IndexDataType,
InGridDesc_M_K,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
InMemoryDataOperationEnum::Set,
PropagateNan,
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize>;
const auto kernel = kernel_reduce_threadwise<GridwiseReduce, const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
NeedIndices, OutputIndex,
HaveIndexInput,
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, AccDataType,
...@@ -252,7 +249,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -252,7 +249,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
InGridDesc_M_K, InGridDesc_M_K,
OutGridDesc_M, OutGridDesc_M,
InElementwiseOperation, InElementwiseOperation,
OutElementwiseOperation>; AccElementwiseOperation>;
avg_time = launch_and_time_kernel(stream_config, avg_time = launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -265,9 +262,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -265,9 +262,10 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
arg.acc_elementwise_op_, arg.acc_elementwise_op_,
arg.alpha_, arg.alpha_,
arg.in_dev_, arg.in_dev_,
nullptr,
arg.beta_, arg.beta_,
arg.out_dev_, arg.out_dev_,
arg.out_indices_dev_); arg.out_index_dev_);
return (avg_time); return (avg_time);
}; };
...@@ -276,7 +274,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -276,7 +274,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
const StreamConfig& stream_config = StreamConfig{}) override const StreamConfig& stream_config = StreamConfig{}) override
{ {
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
} };
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
...@@ -311,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -311,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
if(pArg->invariant_lowest_length % OutDstVectorSize != 0) if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false); return (false);
// TODO: remove this. Should return true, as long as this DeviceOP instance support this // cases with big reduce_total_length should be handled by Blockwise kernel
// case for bigger reduce_total_length size, we are supposed to use BlockWise method for
// better performance
if(pArg->reduce_total_length / KThreadSliceSize >= 32) if(pArg->reduce_total_length / KThreadSliceSize >= 32)
return (false); return (false);
...@@ -321,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -321,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
}; };
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<int> inLengths, MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<int> inStrides, const std::vector<index_t> inStrides,
const std::vector<int> outLengths, const std::vector<index_t> outLengths,
const std::vector<int> outStrides, const std::vector<index_t> outStrides,
const std::vector<int> reduceDims, const std::vector<int> reduceDims,
float alpha, float alpha,
float beta, float beta,
const void* in_dev, const void* in_dev,
const void* in_index_dev,
void* out_dev, void* out_dev,
void* out_indices_dev, void* out_index_dev,
void* workspace_dev,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op) override const AccElementwiseOperation acc_elementwise_op) override
{ {
(void)in_index_dev;
return std::make_unique<Argument>(inLengths, return std::make_unique<Argument>(inLengths,
inStrides, inStrides,
outLengths, outLengths,
...@@ -344,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -344,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
beta, beta,
static_cast<const InDataType*>(in_dev), static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev), static_cast<OutDataType*>(out_dev),
static_cast<IndexDataType*>(out_indices_dev), static_cast<IndexDataType*>(out_index_dev),
static_cast<AccDataType*>(workspace_dev),
in_elementwise_op, in_elementwise_op,
acc_elementwise_op); acc_elementwise_op);
}; };
...@@ -360,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE ...@@ -360,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce<InElementwiseOperation, OutE
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceReducceThreadWise<" << BlockSize << ","; str << "DeviceReduceThreadWise<" << BlockSize << ",";
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on // clang-format on
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#define CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
bool NeedIndices,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation>
__global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
constexpr bool IsSecondCall = false;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
GridwiseReduction::RunWithIndex(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
};
};
template <typename GridwiseReduction,
bool NeedIndices,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation>
__global__ void
kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(!NeedIndices)
{
constexpr bool IsSecondCall = true;
GridwiseReduction::template Run<IsSecondCall>(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
}
else
{
GridwiseReduction::RunSecondCallWithIndex(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
alpha,
p_in_global,
beta,
p_out_global,
p_ws_indices_global,
p_indices_global);
};
};
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename IndexDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
bool PropagateNan,
bool BetaIsZero,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_blockwise
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
template <bool IsSecondCall>
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
if constexpr(IsSecondCall)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
};
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)p_indices_global;
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
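// Main loop: load one MThreadSliceSize x KThreadSliceSize tile per iteration, apply the
// elementwise op in place, accumulate along K within the thread, then advance the source
// window by K_BlockTileSize.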
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
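// Reduce the per-thread partial results across the block through the LDS work buffer,
// one M slice at a time.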
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
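// Only the threads with thread_k_cluster_id == 0 hold final block results; they optionally
// blend in beta times the prior output and then store the reduced values.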
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
}
};
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
(void)p_ws_indices_global;
// LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
index_t indexOffset = 0;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
});
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
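// Main loop: indices are generated on the fly as global K offsets
// (indexOffset + thread K offset + iK); value and index are accumulated together
// thread-locally, then reduced across the block through the LDS buffers.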
do
{
// load the thread slice
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
// for indexed operation, acc_elementwise_op should do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0),
accu_index_buf,
out_grid_desc_m,
out_global_idx_buf);
}
};
__device__ static void
RunSecondCallWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation acc_elementwise_op,
AccDataType alpha,
const InDataType* const __restrict__ p_ws_values_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_global,
const IndexDataType* const __restrict__ p_ws_indices_global,
IndexDataType* const __restrict__ p_indices_global)
{
static_assert(InSrcVectorDim == 1,
"InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!");
using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType,
BlockSize,
Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
(void)in_elementwise_op;
// LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto src_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_ws_values_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto src_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
auto reduce_work_idx_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType,
MThreadSliceSize * KThreadSliceSize,
true>
in_thread_idx_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_1d_id = get_block_1d_id();
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0;
});
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize;
index_t reducedTiles = 0;
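// Second pass over the multi-block workspace: both values and their indices are loaded,
// so no elementwise op is applied and no new indices are generated here.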
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
src_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
src_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < toReduceTiles);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
// for indexed operation, acc_elementwise_op should do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
if constexpr(!BetaIsZero)
{
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0),
accu_value_buf,
out_grid_desc_m,
out_global_val_buf);
threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0),
accu_index_buf,
out_grid_desc_m,
out_global_idx_buf);
}
};
};
} // namespace ck
#endif
...@@ -23,75 +23,86 @@ ...@@ -23,75 +23,86 @@
* SOFTWARE. * SOFTWARE.
* *
*******************************************************************************/ *******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP #ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP #define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp" #include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp" #include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp" #include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "cluster_descriptor.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
namespace ck { namespace ck {
template <typename GridwiseReduction, template <typename GridwiseReduction,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInput,
typename InDataType, typename InDataType,
typename OutDataType,
typename AccDataType, typename AccDataType,
typename IndexDataType, typename IndexDataType,
typename InGridDesc_M_K, typename InGridDesc_M_K,
typename WorkspaceDesc_M_K, typename OutGridDesc_M,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation> typename AccElementwiseOperation>
__global__ void __global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
kernel_partial_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m,
const WorkspaceDesc_M_K workspace_desc_m_k, const InElementwiseOperation in_elementwise_op,
const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op,
const AccElementwiseOperation acc_elementwise_op, index_t block_group_size,
index_t block_group_size, index_t num_k_block_tile_iteration,
index_t num_k_block_tile_iteration, AccDataType alpha,
const InDataType* const __restrict__ p_src_global, const InDataType* const __restrict__ p_in_value_global,
AccDataType* const __restrict__ p_ws_values_global, const IndexDataType* const __restrict__ p_in_index_global,
IndexDataType* const __restrict__ p_ws_indices_global) AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{ {
if constexpr(!NeedIndices) if constexpr(!OutputIndex)
{ {
(void)p_in_index_global;
(void)p_out_index_global;
GridwiseReduction::Run(in_grid_desc_m_k, GridwiseReduction::Run(in_grid_desc_m_k,
workspace_desc_m_k, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
block_group_size, block_group_size,
num_k_block_tile_iteration, num_k_block_tile_iteration,
p_src_global, alpha,
p_ws_values_global, p_in_value_global,
p_ws_indices_global); beta,
p_out_value_global);
} }
else else
{ {
GridwiseReduction::RunWithIndex(in_grid_desc_m_k, GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
workspace_desc_m_k, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
block_group_size, num_k_block_tile_iteration,
num_k_block_tile_iteration, alpha,
p_src_global, p_in_value_global,
p_ws_values_global, p_in_index_global,
p_ws_indices_global); beta,
p_out_value_global,
p_out_index_global);
}; };
}; };
template <typename InDataType, template <typename InDataType,
typename OutDataType,
typename AccDataType, typename AccDataType,
typename IndexDataType, typename IndexDataType,
typename InGridDesc_M_K, typename InGridDesc_M_K,
typename WorkspaceDesc_M_K, typename OutGridDesc_M,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan, bool PropagateNan,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
...@@ -101,14 +112,13 @@ template <typename InDataType, ...@@ -101,14 +112,13 @@ template <typename InDataType,
index_t InSrcVectorDim, index_t InSrcVectorDim,
index_t InSrcVectorSize, index_t InSrcVectorSize,
index_t OutDstVectorSize> index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce struct GridwiseReduction_mk_to_m_multiblock
{ {
static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0), (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>; using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
...@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -127,6 +137,19 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}))); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough; using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -135,43 +158,30 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const WorkspaceDesc_M_K& workspace_desc_m_k, const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size, index_t block_group_size,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
const InDataType* const __restrict__ p_src_global, AccDataType alpha,
AccDataType* const __restrict__ p_ws_values_global, const InDataType* const __restrict__ p_in_value_global,
IndexDataType* const __restrict__ p_ws_indices_global) AccDataType beta,
OutDataType* const __restrict__ p_out_value_global)
{ {
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
(void)p_ws_indices_global;
(void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(zeroVal));
auto workspace_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf = auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
...@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -221,7 +231,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
do do
{ {
threadwise_src_load.Run(in_grid_desc_m_k, threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf, in_global_val_buf,
thread_buffer_desc, thread_buffer_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
...@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -242,58 +252,97 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
reducedTiles++; reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration); } while(reducedTiles < num_k_block_tile_iteration);
// Each block executes multiple parallel reductions on the LDS, and due to the use of constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// vector_load, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}( static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed( static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})); if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
auto threadwise_workspace_store = if(block_group_size == 0 && !float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
false>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
AccDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, OutGridDesc_M,
PassThroughOp, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize>,
Sequence<0, 1>, Sequence<0>,
1, 0,
1, OutDstVectorSize,
InMemoryDataOperationEnum::Set, OutMemoryDataOperation,
1, 1,
true>( true>(
workspace_desc_m_k, out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize),
block_local_id),
PassThroughOp{}); PassThroughOp{});
threadwise_workspace_store.Run(reduced_data_desc, threadwise_dst_store.Run(reduced_data_desc,
make_tuple(I0, I0), make_tuple(I0),
accu_value_buf, accu_value_buf,
workspace_desc_m_k, out_grid_desc_m,
workspace_global_buf); out_global_val_buf);
} }
}; };
template <bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const WorkspaceDesc_M_K& workspace_desc_m_k, const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
const InDataType* const __restrict__ p_src_global, AccDataType alpha,
AccDataType* const __restrict__ p_ws_values_global, const InDataType* const __restrict__ p_in_value_global,
IndexDataType* const __restrict__ p_ws_indices_global) const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{ {
using BlockwiseReduceWithIndex = using BlockwiseReduceWithIndex =
PartitionedBlockwiseReductionWithIndex<AccDataType, PartitionedBlockwiseReductionWithIndex<AccDataType,
IndexDataType, IndexDataType,
BlockSize, BlockSize,
ThreadClusterLengths_M_K, Sequence<MThreadClusterSize, KThreadClusterSize>,
ThreadClusterArrangeOrder, ThreadClusterArrangeOrder,
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
...@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -303,22 +352,24 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
AccDataType, AccDataType,
IndexDataType>; IndexDataType>;
(void)acc_elementwise_op; (void)in_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS // LDS
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ index_t p_reduce_work_idx_buffer[BlockSize]; __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto in_global_buf = const auto zeroVal = ReduceOperation::GetReductionZeroVal();
make_dynamic_buffer<AddressSpaceEnum::Global>(p_src_global,
const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(zeroVal));
auto workspace_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto workspace_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_val_buf = auto reduce_work_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize); make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
...@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -327,6 +378,7 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf; in_thread_val_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
IndexDataType, IndexDataType,
MThreadSliceSize * KThreadSliceSize, MThreadSliceSize * KThreadSliceSize,
...@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -336,10 +388,8 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf; StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_1d_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx = const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
...@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce ...@@ -347,138 +397,239 @@ struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce
const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1]; const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>; using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})); make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType, auto threadwise_src_val_load =
AccDataType, ThreadwiseTensorSliceTransfer_v2<InDataType,
InGridDesc_M_K, AccDataType,
decltype(thread_buffer_desc), InGridDesc_M_K,
ThreadBufferLengths, decltype(thread_buffer_desc),
ThreadBufferDimAccessOrder, ThreadBufferLengths,
InSrcVectorDim, ThreadBufferDimAccessOrder,
InSrcVectorSize, InSrcVectorDim,
1, InSrcVectorSize,
false>( 1,
in_grid_desc_m_k, false>(
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, in_grid_desc_m_k,
block_local_id * reduceSizePerBlock + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
index_t indexOffset = block_local_id * reduceSizePerBlock;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal; accu_value_buf(I) = zeroVal;
accu_index_buf(I) = 0; accu_index_buf(I) = 0;
}); });
index_t reducedTiles = 0; constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
do
{
// load the thread slice
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread to-reduce values index_t reducedTiles = 0;
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation if constexpr(HaveIndexInput)
in_elementwise_op(in_thread_val_buf(Number<offset>{}), {
in_thread_val_buf(Number<offset>{})); auto threadwise_src_idx_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize));
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
in_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
}); });
AccDataType tmpValue = zeroVal; threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
IndexDataType tmpIndex = 0; threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue, reducedTiles++;
in_thread_val_buf[Number<offset>{}], } while(reducedTiles < num_k_block_tile_iteration);
tmpIndex, }
in_thread_idx_buf[Number<offset>{}]); else
{
index_t indexOffset = 0;
do
{
// load the thread slice
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
// initialize the indices for the per-thread values to be reduced
in_thread_idx_buf(Number<offset>{}) =
indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// do element-wise pre-reduction operation
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
AccDataType tmpValue = zeroVal;
IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
AccumulationWithIndex::Calculate(tmpValue,
in_thread_val_buf[Number<offset>{}],
tmpIndex,
in_thread_idx_buf[Number<offset>{}]);
});
BlockwiseReduceWithIndex::Reduce(
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
}); });
BlockwiseReduceWithIndex::Reduce( threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
AccumulationWithIndex::Calculate(
accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
});
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); indexOffset += K_BlockTileSize;
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
};
indexOffset += K_BlockTileSize; constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
reducedTiles++; static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
} while(reducedTiles < num_k_block_tile_iteration); if(thread_k_cluster_id == 0)
{
// for indexed operation, acc_elementwise_op should do nothing
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed( accu_value_buf(I) *= alpha;
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})); }
});
if(thread_k_cluster_id == 0) if(thread_k_cluster_id == 0)
{ {
auto threadwise_workspace_val_store = if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
auto threadwise_dst_load =
ThreadwiseTensorSliceTransfer_v2<OutDataType,
OutDataType,
OutGridDesc_M,
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
1,
true>(
out_grid_desc_m,
make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize));
threadwise_dst_load.Run(out_grid_desc_m,
out_global_val_buf,
reduced_data_desc,
make_tuple(I0),
priorDstValueBuf);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
});
};
auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
AccDataType, OutDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, OutGridDesc_M,
PassThroughOp, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize>,
Sequence<0, 1>, Sequence<0>,
1, 0,
1, OutDstVectorSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
workspace_desc_m_k, out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize),
block_local_id),
PassThroughOp{}); PassThroughOp{});
auto threadwise_workspace_idx_store = auto threadwise_dst_idx_store =
ThreadwiseTensorSliceTransfer_v1r3<IndexDataType, ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
IndexDataType, IndexDataType,
decltype(reduced_data_desc), decltype(reduced_data_desc),
WorkspaceDesc_M_K, OutGridDesc_M,
PassThroughOp, PassThroughOp,
Sequence<MThreadSliceSize, 1>, Sequence<MThreadSliceSize>,
Sequence<0, 1>, Sequence<0>,
1, 0,
1, OutDstVectorSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
true>( true>(
workspace_desc_m_k, out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize + make_multi_index(block_global_1d_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize),
block_local_id),
PassThroughOp{}); PassThroughOp{});
threadwise_workspace_val_store.Run(reduced_data_desc, threadwise_dst_val_store.Run(reduced_data_desc,
make_tuple(I0, I0), make_tuple(I0),
accu_value_buf, accu_value_buf,
workspace_desc_m_k, out_grid_desc_m,
workspace_global_val_buf); out_global_val_buf);
threadwise_workspace_idx_store.Run(reduced_data_desc, threadwise_dst_idx_store.Run(reduced_data_desc,
make_tuple(I0, I0), make_tuple(I0),
accu_index_buf, accu_index_buf,
workspace_desc_m_k, out_grid_desc_m,
workspace_global_idx_buf); out_global_idx_buf);
} }
}; };
}; };
......
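The RunWithIndex path above carries an index alongside every value: whenever the reduce operation prefers a new value, the paired index travels with it, and the indices are either loaded from p_in_index_global (HaveIndexInput == true, e.g. a second-stage reduction over a workspace) or generated on the fly from the running K offset. Below is a minimal host-side sketch of that accumulate-with-index idea for a max reduction; the helper names accumulate_with_index and reduce_row_with_index are illustrative only and plain std::vector stands in for the library's thread/block transfer machinery.

#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical scalar counterpart of AccumulationWithIndex for a "max" reduction:
// keep the running best value together with the index that produced it.
inline void accumulate_with_index(float& best_val, int32_t& best_idx, float val, int32_t idx)
{
    if(val > best_val)
    {
        best_val = val;
        best_idx = idx;
    }
}

// Reduce one row of length K. When in_idx is empty, indices are generated as 0..K-1
// on the fly, mirroring the HaveIndexInput == false branch above; otherwise the
// previously stored indices are carried through, as in a second-stage reduction.
inline void reduce_row_with_index(const std::vector<float>& in_val,
                                  const std::vector<int32_t>& in_idx,
                                  float& out_val,
                                  int32_t& out_idx)
{
    out_val = std::numeric_limits<float>::lowest(); // "reduction zero" for max
    out_idx = 0;
    for(std::size_t k = 0; k < in_val.size(); ++k)
    {
        const int32_t idx = in_idx.empty() ? static_cast<int32_t>(k) : in_idx[k];
        accumulate_with_index(out_val, out_idx, in_val[k], idx);
    }
}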
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace ck {
template <typename GridwiseReduction,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename AccElementwiseOperation>
__global__ void
kernel_reduce_multiblock_atocmi_add(const InGridDesc_M_K in_grid_desc_m_k,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m,
in_elementwise_op,
acc_elementwise_op,
block_group_size,
num_k_block_tile_iteration,
alpha,
p_in_global,
p_out_global);
};
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
bool PropagateNan,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock_atomic_add
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using ThreadBufferDimAccessOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadClusterArrangeOrder =
typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
static constexpr auto thread_cluster_desc =
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
ReduceOperation,
PropagateNan>;
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
PropagateNan>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op,
index_t block_group_size,
index_t num_k_block_tile_iteration,
AccDataType alpha,
const InDataType* const __restrict__ p_in_global,
OutDataType* const __restrict__ p_out_global)
{
const auto zeroVal = ReduceOperation::GetReductionZeroVal();
// LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize];
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal));
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize());
auto reduce_work_buf =
make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id();
const index_t blkgroup_id = block_global_id / block_group_size;
const index_t block_local_id = block_global_id % block_group_size;
const auto thread_cluster_idx =
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_k_cluster_id = thread_cluster_idx[I1];
const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k,
make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
block_local_id * reduceSizePerBlock +
thread_k_cluster_id * KThreadSliceSize));
constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);
index_t reducedTiles = 0;
do
{
threadwise_src_load.Run(in_grid_desc_m_k,
in_global_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedTiles++;
} while(reducedTiles < num_k_block_tile_iteration);
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
// Each block executes multiple parallel reductions on the LDS and atomic-adds its
// reduced output to the global location corresponding to each invariant dimension,
// so that a consistent reduced result is obtained for that dimension. Due to the use
// of vector_load, each block/thread is involved in multiple invariant dimensions.
static_for<0, MThreadSliceSize, 1>{}(
[&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
if(thread_k_cluster_id == 0)
{
acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
accu_value_buf(I) *= alpha;
}
});
if(thread_k_cluster_id == 0)
{
auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThroughOp,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
InMemoryDataOperationEnum::AtomicAdd,
1,
true>(
out_grid_desc_m,
make_multi_index(blkgroup_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize),
PassThroughOp{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf);
}
};
};
} // namespace ck
#endif
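GridwiseReduction_mk_to_m_multiblock_atomic_add above splits the K dimension of each row across block_group_size workgroups; every workgroup reduces its own K-slice, applies acc_elementwise_op and alpha, and then commits its partial result with an atomic add. A rough single-threaded C++ model of that partitioning follows; reduce_row_multiblock is a made-up name and a plain sum stands in for the generic ReduceOperation.

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative CPU model of the multiblock atomic-add reduction for a "sum" op:
// each "block" owns one contiguous K-slice and adds its scaled partial sum into out.
void reduce_row_multiblock(const std::vector<float>& in,   // length K, one row
                           float& out,                     // expected to start at 0
                           std::size_t block_group_size,
                           float alpha)
{
    const std::size_t K           = in.size();
    const std::size_t k_per_block = (K + block_group_size - 1) / block_group_size;

    for(std::size_t block = 0; block < block_group_size; ++block)
    {
        float partial = 0.0f; // "reduction zero" for sum
        const std::size_t k_begin = block * k_per_block;
        const std::size_t k_end   = std::min(K, k_begin + k_per_block);
        for(std::size_t k = k_begin; k < k_end; ++k)
            partial += in[k];

        out += alpha * partial; // on the GPU this is the AtomicAdd store
    }
}

On the GPU side the per-block partial is produced by the LDS blockwise reduction, and only the single accumulation per row goes through global atomics, which is why the destination is expected to be pre-initialized (typically zero-filled) before the kernel launch.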
...@@ -37,7 +37,8 @@ ...@@ -37,7 +37,8 @@
namespace ck { namespace ck {
template <typename GridwiseReduction, template <typename GridwiseReduction,
bool NeedIndices, bool OutputIndex,
bool HaveIndexInput,
typename InDataType, typename InDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename AccDataType,
...@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, ...@@ -51,34 +52,35 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
const InElementwiseOperation in_elementwise_op, const InElementwiseOperation in_elementwise_op,
const AccElementwiseOperation acc_elementwise_op, const AccElementwiseOperation acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_value_global,
const IndexDataType* const __restrict__ p_in_index_global,
AccDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_indices_global) IndexDataType* const __restrict__ p_out_index_global)
{ {
if constexpr(!NeedIndices) if constexpr(!OutputIndex)
{ {
GridwiseReduction::Run(in_grid_desc_m_k, GridwiseReduction::Run(in_grid_desc_m_k,
out_grid_desc_m, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
alpha, alpha,
p_in_global, p_in_value_global,
beta, beta,
p_out_global, p_out_value_global);
p_indices_global);
} }
else else
{ {
GridwiseReduction::RunWithIndices(in_grid_desc_m_k, GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
out_grid_desc_m, out_grid_desc_m,
in_elementwise_op, in_elementwise_op,
acc_elementwise_op, acc_elementwise_op,
alpha, alpha,
p_in_global, p_in_value_global,
beta, p_in_index_global,
p_out_global, beta,
p_indices_global); p_out_value_global,
p_out_index_global);
}; };
}; };
...@@ -91,11 +93,9 @@ template <typename InDataType, ...@@ -91,11 +93,9 @@ template <typename InDataType,
typename ReduceOperation, typename ReduceOperation,
typename InElementwiseOperation, typename InElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
bool PropagateNan, bool PropagateNan,
bool BetaIsZero,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize, index_t MThreadSliceSize,
index_t KThreadSliceSize, index_t KThreadSliceSize,
index_t InSrcVectorDim, index_t InSrcVectorDim,
...@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -125,10 +125,9 @@ struct GridwiseReduction_mk_to_m_threadwise
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
const AccElementwiseOperation& acc_elementwise_op, const AccElementwiseOperation& acc_elementwise_op,
AccDataType alpha, AccDataType alpha,
const InDataType* const __restrict__ p_in_global, const InDataType* const __restrict__ p_in_value_global,
AccDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_global, OutDataType* const __restrict__ p_out_value_global)
IndexDataType* const __restrict__ p_indices_global)
{ {
using ThreadwiseReduce = ThreadwiseReduction<AccDataType, using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K, ThreadReduceSrcDesc_M_K,
...@@ -136,14 +135,14 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -136,14 +135,14 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
(void)p_indices_global;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_val_buf =
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal)); make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf; in_thread_buf;
...@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType, auto threadwise_src_val_load =
AccDataType, ThreadwiseTensorSliceTransfer_v2<InDataType,
InGridDesc_M_K, AccDataType,
decltype(thread_buffer_desc), InGridDesc_M_K,
ThreadBufferLengths, decltype(thread_buffer_desc),
ThreadBufferDimAccessOrder, ThreadBufferLengths,
InSrcVectorDim, ThreadBufferDimAccessOrder,
InSrcVectorSize, InSrcVectorDim,
1, InSrcVectorSize,
false>( 1,
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
index_t reducedLength = 0; index_t reducedLength = 0;
do do
{ {
threadwise_src_load.Run(in_grid_desc_m_k, threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_buf, in_global_val_buf,
thread_buffer_desc, thread_buffer_desc,
make_tuple(I0, I0), make_tuple(I0, I0),
in_thread_buf); in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation // do element-wise pre-reduction operation
...@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize; reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength); } while(reducedLength < toReduceLength);
...@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero) if(!float_equal_zero{}(beta))
{ {
if(!float_equal_zero{}(beta)) auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
{ OutDataType,
auto threadwise_dst_load = OutGridDesc_M,
ThreadwiseTensorSliceTransfer_v2<OutDataType, decltype(reduced_data_desc),
OutDataType, Sequence<MThreadSliceSize>,
OutGridDesc_M, Sequence<0>,
decltype(reduced_data_desc), 0,
Sequence<MThreadSliceSize>, 1,
Sequence<0>, 1,
0, true>(
1, out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
1,
true>( StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); priorDstValue_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true> threadwise_dst_load.Run(out_grid_desc_m,
priorDstValue_buf; dst_global_buf,
reduced_data_desc,
threadwise_dst_load.Run(out_grid_desc_m, make_tuple(I0),
dst_global_buf, priorDstValue_buf);
reduced_data_desc,
make_tuple(I0), static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
priorDstValue_buf); accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
}; };
auto threadwise_dst_store = auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, OutDataType,
OutDataType, decltype(reduced_data_desc),
decltype(reduced_data_desc), OutGridDesc_M,
OutGridDesc_M, PassThroughOp,
PassThroughOp, Sequence<MThreadSliceSize>,
Sequence<MThreadSliceSize>, Sequence<0>,
Sequence<0>, 0,
0, OutDstVectorSize,
OutDstVectorSize, OutMemoryDataOperation,
InMemoryDataOperationEnum::Set, 1,
1, false>(
false>( out_grid_desc_m,
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize),
make_multi_index(thread_global_1d_id * MThreadSliceSize), PassThroughOp{});
PassThroughOp{});
threadwise_dst_store.Run( threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
}; };
__device__ static void RunWithIndices(const InGridDesc_M_K& in_grid_desc_m_k, template <bool HaveIndexInput>
const OutGridDesc_M& out_grid_desc_m, __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const InElementwiseOperation& in_elementwise_op, const OutGridDesc_M& out_grid_desc_m,
const AccElementwiseOperation& acc_elementwise_op, const InElementwiseOperation& in_elementwise_op,
AccDataType alpha, const AccElementwiseOperation& acc_elementwise_op,
const InDataType* const __restrict__ p_in_global, AccDataType alpha,
AccDataType beta, const InDataType* const __restrict__ p_in_value_global,
OutDataType* const __restrict__ p_out_global, const IndexDataType* const __restrict__ p_in_index_global,
IndexDataType* const __restrict__ p_indices_global) AccDataType beta,
OutDataType* const __restrict__ p_out_value_global,
IndexDataType* const __restrict__ p_out_index_global)
{ {
using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType, using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
IndexDataType, IndexDataType,
...@@ -281,12 +278,17 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -281,12 +278,17 @@ struct GridwiseReduction_mk_to_m_threadwise
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto zeroVal = ReduceOperation::GetReductionZeroVal();
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_val_buf =
p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert<InDataType>(zeroVal)); make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc_m.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_indices_global, out_grid_desc_m.GetElementSpaceSize()); p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_val_buf; in_thread_val_buf;
...@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType, auto threadwise_src_val_load =
AccDataType, ThreadwiseTensorSliceTransfer_v2<InDataType,
InGridDesc_M_K, AccDataType,
decltype(thread_buffer_desc), InGridDesc_M_K,
ThreadBufferLengths, decltype(thread_buffer_desc),
ThreadBufferDimAccessOrder, ThreadBufferLengths,
InSrcVectorDim, ThreadBufferDimAccessOrder,
InSrcVectorSize, InSrcVectorDim,
1, InSrcVectorSize,
false>( 1,
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
index_t indexStart = 0; index_t indexStart = 0;
index_t reducedLength = 0; index_t reducedLength = 0;
do if constexpr(HaveIndexInput)
{ {
threadwise_src_load.Run(in_grid_desc_m_k, auto threadwise_src_idx_load =
in_global_buf, ThreadwiseTensorSliceTransfer_v2<IndexDataType,
thread_buffer_desc, IndexDataType,
make_tuple(I0, I0), InGridDesc_M_K,
in_thread_val_buf); decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
threadwise_src_idx_load.Run(in_grid_desc_m_k,
in_global_idx_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
});
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { ThreadwiseReduceWithIndex::Reduce(
// do element-wise pre-reduction operation in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_thread_idx_buf(Number<offset>{}) = indexStart + iK(); threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
in_elementwise_op(in_thread_val_buf(Number<offset>{}), indexStart += KThreadSliceSize;
in_thread_val_buf(Number<offset>{})); reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
}
else
{
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_val_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
in_thread_val_buf(Number<offset>{}));
});
}); });
});
ThreadwiseReduceWithIndex::Reduce( ThreadwiseReduceWithIndex::Reduce(
in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);
threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
indexStart += KThreadSliceSize; indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize; reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength); } while(reducedLength < toReduceLength);
};
// for indexed operation, acc_elementwise_op should do nothing // for indexed operation, acc_elementwise_op should do nothing
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
if constexpr(!BetaIsZero) if(!float_equal_zero{}(beta))
{ {
if(!float_equal_zero{}(beta)) auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
{ OutDataType,
auto threadwise_dst_load = OutGridDesc_M,
ThreadwiseTensorSliceTransfer_v2<OutDataType, decltype(reduced_data_desc),
OutDataType, Sequence<MThreadSliceSize>,
OutGridDesc_M, Sequence<0>,
decltype(reduced_data_desc), 0,
Sequence<MThreadSliceSize>, 1,
Sequence<0>, 1,
0, false>(
1, out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));
1,
false>( StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); priorDstValue_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true> threadwise_dst_load.Run(out_grid_desc_m,
priorDstValue_buf; out_global_val_buf,
reduced_data_desc,
threadwise_dst_load.Run(out_grid_desc_m, make_tuple(I0),
out_global_val_buf, priorDstValue_buf);
reduced_data_desc,
make_tuple(I0), static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
priorDstValue_buf); accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
});
};
}; };
auto threadwise_dst_val_store = auto threadwise_dst_val_store =
...@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>, Sequence<0>,
0, 0,
OutDstVectorSize, OutDstVectorSize,
InMemoryDataOperationEnum::Set, OutMemoryDataOperation,
1, 1,
false>( false>(
out_grid_desc_m, out_grid_desc_m,
...@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise
Sequence<0>, Sequence<0>,
0, 0,
OutDstVectorSize, OutDstVectorSize,
InMemoryDataOperationEnum::Set, OutMemoryDataOperation,
1, 1,
false>( false>(
out_grid_desc_m, out_grid_desc_m,
......
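Both Run and RunWithIndex above apply the same epilogue once the per-row accumulator is final: scale by alpha and, only when beta is non-zero, load the existing destination value and blend it in before the store. A compact sketch of that epilogue is given below; blend_epilogue is a made-up name and a plain comparison against 0.0f stands in for the float_equal_zero{} guard.

#include <cstddef>
#include <vector>

// Illustrative epilogue used after the per-row reduction has finished:
// out[m] = alpha * acc[m] + beta * out_prior[m], skipping the extra global
// read entirely when beta == 0 (mirroring the float_equal_zero{} guard above).
void blend_epilogue(std::vector<float>& out,
                    const std::vector<float>& acc,
                    float alpha,
                    float beta)
{
    for(std::size_t m = 0; m < out.size(); ++m)
    {
        float value = alpha * acc[m];
        if(beta != 0.0f)
            value += beta * out[m]; // prior destination value
        out[m] = value;
    }
}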
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP #pragma once
#define CK_GRIDWISE_GEMM_V1R3_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp" #include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_dl_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v5r1.hpp" #include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename AK0M0M1K1GridDesc, typename AGridDesc_K0_M0_M1_K1,
typename BK0N0N1K1GridDesc, typename BGridDesc_K0_N0_N1_K1,
typename CM0M10M11N0N10N11GridDesc, typename CGridDesc_M0_M10_M11_N0_N10_N11,
typename CBlockIdToM0N0BlockClusterAdaptor, typename Block2CTileMap,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_dlops_v1r3( kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, const Block2CTileMap block_2_ctile_map)
const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
...@@ -43,10 +43,10 @@ __global__ void ...@@ -43,10 +43,10 @@ __global__ void
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block, p_shared_block,
a_k0_m0_m1_k1_grid_desc, a_grid_desc_k0_m0_m1_k1,
b_k0_n0_n1_k1_grid_desc, b_grid_desc_k0_n0_n1_k1,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_grid_desc_m0_m10_m11_n0_n10_n11,
cblockid_to_m0_n0_block_cluster_adaptor, block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
...@@ -56,12 +56,12 @@ template <index_t BlockSize, ...@@ -56,12 +56,12 @@ template <index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename FloatC, typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AK0MK1GridDesc, typename AGridDesc_K0_M_K1,
typename BK0NK1GridDesc, typename BGridDesc_K0_N_K1,
typename CMNGridDesc, typename CGridDesc_M_N,
index_t MPerBlockM1, index_t MPerBlock,
index_t NPerBlockN1, index_t NPerBlock,
index_t KPerBlock, index_t K0PerBlock,
index_t M1PerThreadM111, index_t M1PerThreadM111,
index_t N1PerThreadN111, index_t N1PerThreadN111,
index_t KPerThread, index_t KPerThread,
...@@ -83,13 +83,8 @@ template <index_t BlockSize, ...@@ -83,13 +83,8 @@ template <index_t BlockSize,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector>
typename AGridStepHacks, struct GridwiseGemmDl_km_kn_mn_v1r3
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -97,7 +92,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -97,7 +92,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
...@@ -106,112 +101,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -106,112 +101,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
} }
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CMNGridDesc& c_m_n_grid_desc) const CGridDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
} }
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{ {
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size; return grid_size;
} }
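// The main loop below is double-buffered and consumes 2 * K0PerBlock of K0 per iteration; // The main loop below is double-buffered and consumes 2 * K0PerBlock of K0 per iteration;
// a main loop is generated only when more than one such double step fits into K0. // a main loop is generated only when more than one such double step fits into K0.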
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{ {
const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
return has_main_k_block_loop; return has_main_k_block_loop;
} }
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{ {
const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop; return has_double_tail_k_block_loop;
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
{ {
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto M1 = Number<MPerBlockM1>{}; const auto M1 = Number<MPerBlock>{};
const auto M0 = M / M1; const auto M0 = M / M1;
const auto a_k0_m0_m1_k1_grid_desc = const auto a_grid_desc_k0_m0_m1_k1 =
transform_tensor_descriptor(a_k0_m_k1_grid_desc, transform_tensor_descriptor(a_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)), make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_grid_desc; return a_grid_desc_k0_m0_m1_k1;
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{ {
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto N1 = Number<NPerBlockN1>{}; const auto N1 = Number<NPerBlock>{};
const auto N0 = N / N1; const auto N0 = N / N1;
const auto b_k0_n0_n1_k1_grid_desc = const auto b_grid_desc_k0_n0_n1_k1 =
transform_tensor_descriptor(b_k0_n_k1_grid_desc, transform_tensor_descriptor(b_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(K0), make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)), make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_grid_desc; return b_grid_desc_k0_n0_n1_k1;
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = c_m_n_grid_desc.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{}; constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlockN1>{}; constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1; const auto M0 = M / M1;
const auto N0 = N / N1; const auto N0 = N / N1;
...@@ -226,41 +221,29 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -226,41 +221,29 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto M10 = M1 / M11; constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11; constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
c_m_n_grid_desc, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))), make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_m0_m10_m11_n0_n10_n11_grid_desc; return c_grid_desc_m0_m10_m11_n0_n10_n11;
} }
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = c_m_n_grid_desc.GetLength(I0); return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
const auto N = c_m_n_grid_desc.GetLength(I1); c_grid_desc_m_n);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto cblockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return cblockid_to_m0_n0_block_cluster_adaptor;
} }
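// Sketch (illustration only, not the CK implementation): the merge-based adaptor on
// the left decomposes the flat block id into a (m0, n0) C-tile index as
//   m0 = block_id / N0;   // with N0 = N / NPerBlock
//   n0 = block_id % N0;
// BlockToCTileMap_M00_N00_M01_N01 on the right, as its name suggests, additionally
// splits (m0, n0) into (m00, m01) and (n00, n01) groups before flattening.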
using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); using CGridDesc_M0_M10_M11_N0_N10_N11 =
using CBlockIdToM0N0BlockClusterAdaptor = decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void __device__ static void
...@@ -268,57 +251,64 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -268,57 +251,64 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, const Block2CTileMap& block_2_ctile_map,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx = const auto c_m0_n0_block_cluster_idx =
cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
make_multi_index(get_block_1d_id()));
// HACK: this forces index data into SGPR // HACK: this forces index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
if(!block_2_ctile_map.ValidCTileIndex(
make_tuple(im0, in0),
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
{
return;
}
// TODO: change this. I think it needs multi-dimensional alignment // TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM // A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// TODO: check alignment // TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM // B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align); make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() && a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize() && b_k0_n_k1_block_desc.GetElementSpaceSize() &&
"wrong!"); "wrong!");
...@@ -326,14 +316,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -326,14 +316,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>, Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_k0_m0_m1_k1_grid_desc), remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
decltype(a_k0_m0_m1_k1_block_desc), decltype(a_block_desc_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
...@@ -341,23 +331,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -341,23 +331,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false, false,
true>(a_k0_m0_m1_k1_grid_desc, true>(a_grid_desc_k0_m0_m1_k1,
make_multi_index(0, im0, 0, 0), make_multi_index(0, im0, 0, 0),
a_k0_m0_m1_k1_block_desc, a_block_desc_k0_m0_m1_k1,
make_multi_index(0, 0, 0, 0)); make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy // B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize, BlockSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>, Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(b_k0_n0_n1_k1_grid_desc), remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
decltype(b_k0_n0_n1_k1_block_desc), decltype(b_block_desc_k0_n0_n1_k1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
...@@ -365,19 +355,19 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -365,19 +355,19 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false, false,
true>(b_k0_n0_n1_k1_grid_desc, true>(b_grid_desc_k0_n0_n1_k1,
make_multi_index(0, in0, 0, 0), make_multi_index(0, in0, 0, 0),
b_k0_n0_n1_k1_block_desc, b_block_desc_k0_n0_n1_k1,
make_multi_index(0, 0, 0, 0)); make_multi_index(0, 0, 0, 0));
// GEMM definition // GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx // c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlockM1] is in LDS // a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlockN1] is in LDS // b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
...@@ -395,58 +385,53 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -395,58 +385,53 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple( constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple( constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block; FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output // register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>( auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc, // Initialize C
decltype(c_m10_m11_n10_n11_thread_desc), c_thread_buf.Clear();
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size, p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size, p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
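// LDS layout sketch (illustration only): the shared allocation is carved into
//   [ A even | A odd | B even | B odd ]
// at element offsets 0, a_block_aligned_space_size, 2 * a_block_aligned_space_size
// and 2 * a_block_aligned_space_size + b_block_aligned_space_size, which is why
// p_b_block_double starts 2 * a_block_aligned_space_size past p_shared_block.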
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
} }
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
index_t k_block_data_begin = 0; index_t k_block_data_begin = 0;
...@@ -455,82 +440,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -455,82 +440,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
do do
{ {
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step, a_block_slice_copy_step);
AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step);
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
a_block_even_buf, a_block_even_buf,
b_block_even_buf, b_block_even_buf,
c_thread_buf); c_thread_buf);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
a_block_slice_copy_step, a_block_slice_copy_step);
AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step);
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
block_sync_lds();
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
blockwise_gemm.Run( blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock; k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * KPerBlock); } while(k_block_data_begin < K0 - 2 * K0PerBlock);
} }
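// Pipeline sketch (illustration only): each trip of the loop above consumes two
// K0PerBlock slabs and ping-pongs the two LDS buffers, so the global reads for the
// next slab overlap the blockwise GEMM on the current one:
//   even half: move window -> RunRead(global) -> sync -> GEMM(even bufs) -> RunWrite(odd bufs)
//   odd  half: move window -> RunRead(global) -> sync -> GEMM(odd bufs)  -> RunWrite(even bufs)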
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
{ {
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
__syncthreads(); block_sync_lds();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run( blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
__syncthreads(); block_sync_lds();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run( blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
} }
else // if 1 iteration is left else // if 1 iteration is left
{ {
...@@ -538,12 +517,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -538,12 +517,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run( blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
} }
// output: register to global memory // output: register to global memory
{ {
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
make_naive_tensor_descriptor_packed( make_naive_tensor_descriptor_packed(
make_tuple(I1, make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{}, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
...@@ -559,8 +538,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -559,8 +538,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
ThreadwiseTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
FloatC, FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0], c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1], c_m10_m11_n10_n11_thread_tensor_lengths[I1],
...@@ -572,22 +552,21 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -572,22 +552,21 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc, true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
make_multi_index(im0, make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0, in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc, ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_grid_desc_m0_m10_m11_n0_n10_n11,
c_grid_buf, c_grid_buf);
CGridStepHacks{});
} }
} }
}; };
} // namespace ck } // namespace ck
#endif
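The main loop above is steered by two compile-time flags chosen on the host. As a rough sketch (plain C++, not the CK helpers; the names and exact thresholds are assumptions consistent with the loop structure above), the flags could be derived from K0 and K0PerBlock like this:

#include <cassert>

// Hypothetical host-side helper, not the CK implementation.
struct TailFlags
{
    bool has_main_k_block_loop;        // enough slabs for at least one even/odd double iteration
    bool has_double_tail_k_block_loop; // an even slab count leaves a two-slab tail
};

inline TailFlags make_tail_flags(int k0, int k0_per_block)
{
    assert(k0 % k0_per_block == 0 && k0 >= k0_per_block);
    const int num_slabs = k0 / k0_per_block; // number of K0PerBlock-sized slabs along K0
    return TailFlags{num_slabs > 2, num_slabs % 2 == 0};
}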
#pragma once #pragma once
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
......
#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP #pragma once
#define CK_THREADWISE_CONTRACTION_DLOPS_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "math.hpp" #include "math.hpp"
...@@ -25,9 +23,9 @@ template <typename FloatA, ...@@ -25,9 +23,9 @@ template <typename FloatA,
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1
{ {
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
{ {
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
...@@ -124,9 +122,9 @@ template <typename FloatA, ...@@ -124,9 +122,9 @@ template <typename FloatA,
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{ {
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{ {
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
...@@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ ...@@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
}; };
} // namespace ck } // namespace ck
#endif
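For reference, the contraction these threadwise structs compute can be written as plain nested loops. The following is only an illustrative host-side sketch of the A[K0][M0][M1][K1] x B[K0][N0][N1][K1] -> C[M0][M1][N0][N1] semantics implied by the type names; the array form and template parameters are assumptions, not CK code:

template <int K0, int M0, int M1, int N0, int N1, int K1>
void reference_contraction(const float (&a)[K0][M0][M1][K1],
                           const float (&b)[K0][N0][N1][K1],
                           float (&c)[M0][M1][N0][N1])
{
    // c[m0][m1][n0][n1] += sum over (k0, k1) of a[k0][m0][m1][k1] * b[k0][n0][n1][k1]
    for(int k0 = 0; k0 < K0; ++k0)
        for(int m0 = 0; m0 < M0; ++m0)
            for(int m1 = 0; m1 < M1; ++m1)
                for(int n0 = 0; n0 < N0; ++n0)
                    for(int n1 = 0; n1 < N1; ++n1)
                        for(int k1 = 0; k1 < K1; ++k1)
                            c[m0][m1][n0][n1] +=
                                a[k0][m0][m1][k1] * b[k0][n0][n1][k1];
}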
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP #pragma once
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1 ...@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1
}; };
} // namespace ck } // namespace ck
#endif
...@@ -325,7 +325,7 @@ struct DynamicBuffer ...@@ -325,7 +325,7 @@ struct DynamicBuffer
{ {
if(is_valid_element) if(is_valid_element)
{ {
atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x); atomic_add(c_style_pointer_cast<X*>(&p_data_[i]), x);
} }
} }
} }
......