Commit 644df335 authored by rocking

Merge branch 'develop' into gemm_layernorm_instance

parents d99640ab 7494c1c6
@@ -13,10 +13,16 @@ namespace ck {
namespace tensor_operation {
namespace device {

-template <index_t Rank,
+template <typename InDataType,
+          typename AccDataType,
+          typename OutDataType,
+          index_t Rank,
           index_t NumReduceDim,
+          typename ReduceOperation,
           typename InElementwiseOperation,
-          typename AccElementwiseOperation>
+          typename AccElementwiseOperation,
+          bool PropagateNan,
+          bool OutputIndex>
struct DeviceReduce : public BaseOperator
{
    static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
@@ -27,8 +33,8 @@ struct DeviceReduce : public BaseOperator
                        const std::array<index_t, NumOutDim> outLengths,
                        const std::array<index_t, NumOutDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
-                        float alpha,
-                        float beta,
+                        double alpha,
+                        double beta,
                        const void* in_dev,
                        const void* in_index_dev,
                        void* out_dev,
@@ -39,12 +45,26 @@ struct DeviceReduce : public BaseOperator
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

-template <index_t Rank,
+template <typename InDataType,
+          typename AccDataType,
+          typename OutDataType,
+          index_t Rank,
           index_t NumReduceDim,
+          typename ReduceOperation,
           typename InElementwiseOperation,
-          typename AccElementwiseOperation>
-using DeviceReducePtr = std::unique_ptr<
-    DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;
+          typename AccElementwiseOperation,
+          bool PropagateNan,
+          bool OutputIndex>
+using DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType,
+                                                     AccDataType,
+                                                     OutDataType,
+                                                     Rank,
+                                                     NumReduceDim,
+                                                     ReduceOperation,
+                                                     InElementwiseOperation,
+                                                     AccElementwiseOperation,
+                                                     PropagateNan,
+                                                     OutputIndex>>;

} // namespace device
} // namespace tensor_operation
...
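For reference, a hypothetical instantiation of the widened DeviceReducePtr alias; the data types, reduce operation, and elementwise operations below are illustrative stand-ins rather than something defined by this commit (the concrete instances live in the library's instance targets).

// Illustrative sketch only, assuming ck::half_t, ck::reduce::Add and
// element_wise::PassThrough as placeholder template arguments.
using HypotheticalReducePtr = ck::tensor_operation::device::DeviceReducePtr<
    ck::half_t,                                       // InDataType
    float,                                            // AccDataType
    ck::half_t,                                       // OutDataType
    4,                                                // Rank
    1,                                                // NumReduceDim
    ck::reduce::Add,                                  // ReduceOperation
    ck::tensor_operation::element_wise::PassThrough,  // InElementwiseOperation
    ck::tensor_operation::element_wise::PassThrough,  // AccElementwiseOperation
    false,                                            // PropagateNan
    false>;                                           // OutputIndex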
@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
    // @param[in] inLengths  Input tensor extent(s) from high to low dimension
    // @param[in] inStrides  Input tensor stride(s) from high to low dimension
    // @param[in] reduceDims The dimension(s) the normalization operation is applied
-    // @param[in] alpha      Typeless pointer in host memory storing the alpha scaling
-    //                       value as type AccDataType
-    // @param[in] beta       Typeless pointer in host memory storing the beta scaling
-    //                       value as type AccDataType
+    // @param[in] alpha      double type value
+    // @param[in] beta       double type value
    // @param[in] in_dev     Typeless const pointer in device memory storing the input
    //                       tensor
    // @param     out_dev    Typeless pointer in device memory storing the output tensor
@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
    MakeArgumentPointer(const std::vector<index_t> inLengths,
                        const std::vector<index_t> inStrides,
                        const std::vector<int> reduceDims,
-                        const void* alpha,
-                        const void* beta,
+                        double alpha,
+                        double beta,
                        const void* in_dev,
                        void* out_dev,
                        InElementwiseOp in_elementwise_op,
...
@@ -8,7 +8,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
-#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -26,10 +26,10 @@ template <typename InDataTypeTuple,
          index_t NPerThread,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq>
-struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
+struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
                                                          OutDataTypeTuple,
                                                          ElementwiseOperation,
                                                          NumDim_m + NumDim_n>
{
    static constexpr index_t NumDim = NumDim_m + NumDim_n;
...
@@ -8,7 +8,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
-#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -25,8 +25,8 @@ template <typename InDataTypeTuple,
          index_t MPerThread,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq>
-struct DeviceElementwise
-    : public DeviceElementwiseBase<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
+struct DeviceElementwiseImpl
+    : public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
    static constexpr int NumInput  = InDataTypeTuple::Size();
    static constexpr int NumOutput = OutDataTypeTuple::Size();
...
@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl
                 const std::vector<index_t> reduceDims,
                 XElementwiseOperation x_elementwise_op,
                 YElementwiseOperation y_elementwise_op,
-                 AccDataType epsilon,
+                 double epsilon,
                 const std::array<const void*, NumInput> in_dev_buffers,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 YDataType* p_y)
-            : epsilon_(epsilon),
-              p_gamma_(p_gamma),
+            : p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
              x_elementwise_op_(x_elementwise_op),
              y_elementwise_op_(y_elementwise_op)
        {
+            epsilon_ = static_cast<AccDataType>(epsilon);
            Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
            for(int i = 0; i < NumInput; i++)
@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> reduceDims,
-                        AccDataType epsilon,
+                        double epsilon,
                        const std::array<const void*, NumInput> in_dev_buffers,
                        const void* p_gamma,
                        const void* p_beta,
...
@@ -431,9 +431,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

-            const auto K =
-                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -471,6 +468,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                    arg.block_2_etile_map_);
            };

+            const auto K = arg.a_grid_desc_m_k_.GetLength(I1);

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
...
@@ -486,7 +486,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...
@@ -73,8 +73,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
        static_for<0, NumReduction, 1>{}([&](auto I) {
            using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
            flag =
-                flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                            OutDataType>::value;
+                flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                             OutDataType>::value;
        });

        return flag;
@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
                 const std::array<index_t, NumOutputDim>& outLengths,
                 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
                 const std::array<int, NumReduceDim>& reduceDims,
-                 const std::array<const void*, NumReduction>& alphas,
-                 const std::array<const void*, NumReduction>& betas,
+                 const std::array<double, NumReduction>& alphas,
+                 const std::array<double, NumReduction>& betas,
                 const void* in_dev,
                 const std::array<void*, NumReduction>& out_dev_buffers,
                 const InElementwiseOperationTuple in_elementwise_op_tuple,
@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
            for(size_t i = 0; i < NumReduction; i++)
            {
-                alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
-                beta_values_(i)  = *static_cast<const AccDataType*>(betas[i]);
+                alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
+                beta_values_(i)  = static_cast<AccDataType>(betas[i]);
            };

            in_dev_ = static_cast<const InDataType*>(in_dev);
@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
                        const std::array<index_t, NumOutputDim> outLengths,
                        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
                        const std::array<int, NumReduceDim> reduceDims,
-                        const std::array<const void*, NumReduction> alphas,
-                        const std::array<const void*, NumReduction> betas,
+                        const std::array<double, NumReduction> alphas,
+                        const std::array<double, NumReduction> betas,
                        const void* in_dev,
                        const std::array<void*, NumReduction> out_dev_buffers,
                        const InElementwiseOperationTuple in_elementwise_op_tuple,
...
@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
                 const std::array<index_t, NumOutputDim>& outLengths,
                 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
                 const std::array<int, NumReduceDim>& reduceDims,
-                 const std::array<const void*, NumReduction>& alphas,
-                 const std::array<const void*, NumReduction>& betas,
+                 const std::array<double, NumReduction>& alphas,
+                 const std::array<double, NumReduction>& betas,
                 const void* in_dev,
                 const std::array<void*, NumReduction>& out_dev_buffers,
                 const InElementwiseOperationTuple in_elementwise_op_tuple,
@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
            for(size_t i = 0; i < NumReduction; i++)
            {
-                alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
-                beta_values_(i)  = *static_cast<const AccDataType*>(betas[i]);
+                alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
+                beta_values_(i)  = static_cast<AccDataType>(betas[i]);
            };

            in_dev_ = static_cast<const InDataType*>(in_dev);
@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
                        const std::array<index_t, NumOutputDim> outLengths,
                        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
                        const std::array<int, NumReduceDim> reduceDims,
-                        const std::array<const void*, NumReduction> alphas,
-                        const std::array<const void*, NumReduction> betas,
+                        const std::array<double, NumReduction> alphas,
+                        const std::array<double, NumReduction> betas,
                        const void* in_dev,
                        const std::array<void*, NumReduction> out_dev_buffers,
                        const InElementwiseOperationTuple in_elementwise_op_tuple,
...
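The practical effect of the alphas/betas change above, sketched from the caller's side; NumReduction and the scaling values are hypothetical, chosen only to show the old pointer-based form next to the new value-based form.

#include <array>

int main()
{
    constexpr int NumReduction = 2;

    // before: typeless host pointers that the device op dereferenced as AccDataType
    // float alpha0 = 1.0f, alpha1 = 1.0f;
    // std::array<const void*, NumReduction> alphas{&alpha0, &alpha1};

    // after: plain double values; the Argument constructor now performs
    // alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
    std::array<double, NumReduction> alphas{1.0, 1.0};
    std::array<double, NumReduction> betas{0.0, 0.0};

    (void)alphas;
    (void)betas;
    return 0;
}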
@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 const std::vector<index_t> yStrides,
                 const std::vector<index_t> reduceDims,
                 AccElementwiseOperation acc_elementwise_op,
-                 AccDataType epsilon,
+                 double epsilon,
                 const XDataType* p_x,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 YDataType* p_y)
-            : epsilon_(epsilon),
-              p_x_(p_x),
+            : p_x_(p_x),
              p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
              acc_elementwise_op_(acc_elementwise_op)
        {
+            epsilon_ = static_cast<AccDataType>(epsilon);
            Lengths_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
            xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
            yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> reduceDims,
-                        AccDataType epsilon,
+                        double epsilon,
                        const void* p_x,
                        const void* p_gamma,
                        const void* p_beta,
...
@@ -40,8 +40,16 @@ template <typename InDataType,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
-struct DeviceReduceMultiBlock
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
@@ -67,8 +75,8 @@ struct DeviceReduceMultiBlock
    static constexpr bool use_multiblock =
        (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);

-    static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                      OutDataType>::value,
+    static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                       OutDataType>::value,
                  "The OutDataType must support the specified OutMemoryDataOperation!");

    static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
@@ -209,8 +217,8 @@ struct DeviceReduceMultiBlock
                 const std::array<index_t, NumDstDim> outLengths,
                 const std::array<index_t, NumDstDim> outStrides,
                 const std::array<int, NumReduceDim> reduceDims,
-                 float alpha,
-                 float beta,
+                 double alpha,
+                 double beta,
                 const InDataType* in_dev,
                 const IndexDataType* in_index_dev,
                 OutDataType* out_dev,
@@ -494,8 +502,8 @@ struct DeviceReduceMultiBlock
                        const std::array<index_t, NumDstDim> outLengths,
                        const std::array<index_t, NumDstDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
-                        float alpha,
-                        float beta,
+                        double alpha,
+                        double beta,
                        const void* in_dev,
                        const void* in_index_dev,
                        void* out_dev,
...
@@ -35,8 +35,17 @@ template <typename InDataType,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
-struct DeviceReduceThreadWise
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
@@ -156,8 +165,8 @@ struct DeviceReduceThreadWise
                 const std::array<index_t, NumDstDim> outLengths,
                 const std::array<index_t, NumDstDim> outStrides,
                 const std::array<int, NumReduceDim> reduceDims,
-                 float alpha,
-                 float beta,
+                 double alpha,
+                 double beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 IndexDataType* out_index_dev,
@@ -332,8 +341,8 @@ struct DeviceReduceThreadWise
                        const std::array<index_t, NumDstDim> outLengths,
                        const std::array<index_t, NumDstDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
-                        float alpha,
-                        float beta,
+                        double alpha,
+                        double beta,
                        const void* in_dev,
                        const void* in_index_dev,
                        void* out_dev,
...
@@ -156,19 +156,20 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
        Argument(const std::vector<index_t> inLengths,
                 const std::vector<index_t> inStrides,
                 const std::vector<index_t> reduceDims,
-                 AccDataType alpha,
-                 AccDataType beta,
+                 double alpha,
+                 double beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 InElementwiseOp in_elementwise_op,
                 AccElementwiseOp acc_elementwise_op)
-            : alpha_{alpha},
-              beta_{beta},
-              in_dev_{in_dev},
+            : in_dev_{in_dev},
              out_dev_{out_dev},
              in_elementwise_op_{in_elementwise_op},
              acc_elementwise_op_{acc_elementwise_op}
        {
+            alpha_ = static_cast<AccDataType>(alpha);
+            beta_  = static_cast<AccDataType>(beta);
            if(Rank != inLengths.size() || Rank != inStrides.size() ||
               NumReduceDim != reduceDims.size())
            {
@@ -336,8 +337,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
    static auto MakeArgument(const std::vector<index_t> inLengths,
                             const std::vector<index_t> inStrides,
                             const std::vector<int> reduceDims,
-                             const AccDataType alpha,
-                             const AccDataType beta,
+                             double alpha,
+                             double beta,
                             const InDataType* in_dev,
                             OutDataType* out_dev,
                             InElementwiseOp in_elementwise_op,
@@ -375,8 +376,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                      const std::vector<index_t> inStrides,
                                                      const std::vector<int> reduceDims,
-                                                      const void* alpha,
-                                                      const void* beta,
+                                                      double alpha,
+                                                      double beta,
                                                      const void* in_dev,
                                                      void* out_dev,
                                                      InElementwiseOp in_elementwise_op,
@@ -385,8 +386,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          reduceDims,
-                                          *static_cast<const AccDataType*>(alpha),
-                                          *static_cast<const AccDataType*>(beta),
+                                          alpha,
+                                          beta,
                                          static_cast<const InDataType*>(in_dev),
                                          static_cast<OutDataType*>(out_dev),
                                          in_elementwise_op,
...
@@ -12,7 +12,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp"

namespace ck {
namespace tensor_operation {
@@ -24,16 +24,17 @@ template <typename EmbType,
          typename BetaDataType,
          typename AccDataType,
          typename OutType,
+          typename EmbElementwiseOperation,
          ck::index_t BlockSize,
          ck::index_t DimClusterSize,
          ck::index_t RowClusterSize,
          ck::index_t DimPerBlock,
          ck::index_t RowPerBlock,
          ck::index_t DimThreadSize,
-          ck::index_t RowVectorSize>
-struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
+          ck::index_t RowVectorSize,
+          ck::index_t NumEmbeddings>
+struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
{
    static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
    {
        return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));
@@ -42,96 +43,79 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
    struct Argument : public BaseArgument
    {
        Argument(OutType* p_out,
-                 const EmbType* p_emb_a,
-                 const EmbType* p_emb_b,
-                 const EmbType* p_emb_c,
-                 const IndexType* p_index_a,
-                 const IndexType* p_index_b,
-                 const IndexType* p_index_c,
+                 const ck::Array<EmbType*, NumEmbeddings>& p_embs,
+                 const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
-                 const ck::index_t NumRows,
                 const ck::index_t EmbeddingDim,
                 const ck::index_t IndexLength,
-                 const AccDataType epsilon)
+                 const AccDataType epsilon,
+                 const EmbElementwiseOperation emb_elementwise_op)
            : p_out_(p_out),
-              p_emb_a_(p_emb_a),
-              p_emb_b_(p_emb_b),
-              p_emb_c_(p_emb_c),
-              p_index_a_(p_index_a),
-              p_index_b_(p_index_b),
-              p_index_c_(p_index_c),
+              p_embs_(p_embs),
+              p_indexs_(p_indexs),
              p_gamma_(p_gamma),
              p_beta_(p_beta),
-              NumRows_(NumRows),
              EmbeddingDim_(EmbeddingDim),
              IndexLength_(IndexLength),
-              epsilon_(epsilon)
+              epsilon_(epsilon),
+              emb_elementwise_op_(emb_elementwise_op)
        {
            grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
        }

        OutType* p_out_;
-        const EmbType* p_emb_a_;
-        const EmbType* p_emb_b_;
-        const EmbType* p_emb_c_;
-        const IndexType* p_index_a_;
-        const IndexType* p_index_b_;
-        const IndexType* p_index_c_;
+        ck::Array<EmbType*, NumEmbeddings> p_embs_;
+        ck::Array<IndexType*, NumEmbeddings> p_indexs_;
        const GammaDataType* p_gamma_;
        const BetaDataType* p_beta_;
-        ck::index_t NumRows_;
        ck::index_t EmbeddingDim_;
        ck::index_t IndexLength_;
        AccDataType epsilon_;
+        EmbElementwiseOperation emb_elementwise_op_;
        size_t grid_size_;
    };

-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(void* p_out,
-                                                              const void* p_emb_a,
-                                                              const void* p_emb_b,
-                                                              const void* p_emb_c,
-                                                              const void* p_index_a,
-                                                              const void* p_index_b,
-                                                              const void* p_index_c,
-                                                              const void* p_gamma,
-                                                              const void* p_beta,
-                                                              ck::index_t NumRows,
-                                                              ck::index_t EmbeddingDim,
-                                                              ck::index_t IndexLength,
-                                                              const AccDataType epsilon)
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(void* p_out,
+                        const ck::Array<EmbType*, NumEmbeddings>& p_embs,
+                        const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
+                        const void* p_gamma,
+                        const void* p_beta,
+                        ck::index_t EmbeddingDim,
+                        ck::index_t IndexLength,
+                        const AccDataType epsilon,
+                        const EmbElementwiseOperation emb_elementwise_op)
    {
        return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
-                                          reinterpret_cast<const EmbType*>(p_emb_a),
-                                          reinterpret_cast<const EmbType*>(p_emb_b),
-                                          reinterpret_cast<const EmbType*>(p_emb_c),
-                                          reinterpret_cast<const IndexType*>(p_index_a),
-                                          reinterpret_cast<const IndexType*>(p_index_b),
-                                          reinterpret_cast<const IndexType*>(p_index_c),
+                                          p_embs,
+                                          p_indexs,
                                          reinterpret_cast<const GammaDataType*>(p_gamma),
                                          reinterpret_cast<const BetaDataType*>(p_beta),
-                                          NumRows,
                                          EmbeddingDim,
                                          IndexLength,
-                                          epsilon);
+                                          epsilon,
+                                          emb_elementwise_op);
    }

    using GridwiseSparseEmbedding =
-        GridwiseSparseEmbedding3ForwardLayernorm<EmbType,
+        GridwiseSparseEmbeddingsForwardLayernorm<EmbType,
                                                 IndexType,
                                                 GammaDataType,
                                                 BetaDataType,
                                                 AccDataType,
                                                 OutType,
                                                 decltype(MakeOutputDescriptor(1, 1)),
+                                                 EmbElementwiseOperation,
                                                 BlockSize,
                                                 DimClusterSize,
                                                 RowClusterSize,
                                                 DimPerBlock,
                                                 RowPerBlock,
                                                 DimThreadSize,
-                                                 RowVectorSize>;
+                                                 RowVectorSize,
+                                                 NumEmbeddings>;

    struct Invoker : public BaseInvoker
    {
@@ -139,14 +123,16 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
        {
            auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
            const auto kernel_main =
-                kernel_sparse_embedding3_forward_layernorm<GridwiseSparseEmbedding,
+                kernel_sparse_embeddings_forward_layernorm<GridwiseSparseEmbedding,
                                                           EmbType,
                                                           IndexType,
                                                           GammaDataType,
                                                           BetaDataType,
                                                           AccDataType,
                                                           OutType,
-                                                           decltype(out_desc)>;
+                                                           decltype(out_desc),
+                                                           EmbElementwiseOperation,
+                                                           NumEmbeddings>;
            float avg_time = 0;
            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,
@@ -154,16 +140,13 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
                                               dim3(BlockSize),
                                               0,
                                               arg.p_out_,
-                                               arg.p_emb_a_,
-                                               arg.p_emb_b_,
-                                               arg.p_emb_c_,
-                                               arg.p_index_a_,
-                                               arg.p_index_b_,
-                                               arg.p_index_c_,
+                                               arg.p_embs_,
+                                               arg.p_indexs_,
                                               arg.p_gamma_,
                                               arg.p_beta_,
                                               out_desc,
-                                               arg.epsilon_);
+                                               arg.epsilon_,
+                                               arg.emb_elementwise_op_);
            return (avg_time);
        }
@@ -177,7 +160,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
    static bool IsSupportedArgument(const Argument* p_arg)
    {
-        return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0);
+        return (RowPerBlock == p_arg->EmbeddingDim_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
@@ -195,7 +178,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
        auto str = std::stringstream();
        // clang-format off
-        str << "DeviceSparseEmbedding3ForwardLayernorm_"<< BlockSize << "_" <<
+        str << "DeviceSparseEmbeddingsForwardLayernorm_"<< BlockSize << "_" <<
                DimClusterSize << "x" << RowClusterSize << "_" <<
                DimPerBlock << "x" << RowPerBlock << "_" <<
                DimThreadSize << "x" << RowVectorSize;
...
@@ -172,6 +172,42 @@ struct AddAdd
    }
};
// C = A * B
// E = (C + D0) x D1
struct AddMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (c + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (type_convert<half_t>(c) + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const float y = (c + d0) * d1;
e = y;
}
};
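A minimal host-side check of the epilogue math AddMultiply encodes, presumably intended as the CDE elementwise operation of a GEMM with two auxiliary D tensors; plain float stands in for half_t so the sketch has no CK dependencies.

#include <cassert>

int main()
{
    // e = (c + d0) * d1, the same per-element formula the functor applies
    const float c = 2.0f, d0 = 1.0f, d1 = 3.0f;
    float e = 0.0f;

    e = (c + d0) * d1;
    assert(e == 9.0f);
    return 0;
}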
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
@@ -278,6 +314,40 @@ struct Normalize
    double epsilon_;
};
// used by BatchNorm inference
// y = gamma * (x-mean) / sqrt(epsilon+variance) + beta
// The data type of mean and variance is used as AccDataType
struct NormalizeInInfer
{
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2, typename T3, typename T4>
__host__ __device__ constexpr void operator()(T1& y,
const T1& x,
const T2& mean,
const T2& variance,
const T3& gamma,
const T4& beta) const
{
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T2 tmp_x, tmp_y;
tmp_x = type_convert<T2>(x);
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
type_convert<T2>(gamma) +
type_convert<T2>(beta);
y = type_convert<T1>(tmp_y);
};
double epsilon_;
};
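A self-contained reference computation of the batch-norm inference formula applied above, y = gamma * (x - mean) / sqrt(variance + epsilon) + beta; the values are arbitrary and CK types are replaced by plain float/double.

#include <cmath>
#include <cstdio>

int main()
{
    const double epsilon = 1e-4;
    const float x = 2.0f, gamma = 0.5f, beta = 0.1f;
    const double mean = 1.0, variance = 4.0;

    // accumulate in double, as the functor does when the mean/variance type T2 is double
    const double y = (static_cast<double>(x) - mean) / std::sqrt(variance + epsilon) *
                         static_cast<double>(gamma) +
                     static_cast<double>(beta);

    std::printf("y = %f\n", y); // ~0.35
    return 0;
}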
template <typename Y, typename X>
struct UnaryTypeConvert;
...
@@ -154,6 +154,50 @@ struct BlockToCTileMap_M00_N0_M01Adapt
        index_t idx_M01          = idx_M0 % M01_;
        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                          idx_N0_M01_local / M01_adapt);
    }
...
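A standalone re-run of the worked example in the new comment block. The M01_adapt and idx_M00 computations below fall outside the excerpt, so they are assumptions reconstructed from the example values (M01_adapt = 2, idx_M00 = 0); the rest mirrors the lines shown above.

#include <cstdio>
#include <tuple>

int main()
{
    // example values from the comment: M0 = 5, N0 = 4, block_1d_id = 5, M01 = 2
    const int M0 = 5, N0 = 4, M01 = 2, block_1d_id = 5;

    const int idx_N0 = block_1d_id % N0; // 1
    const int idx_M0 = block_1d_id / N0; // 1

    // assumed: the last, possibly partial, M01 group shrinks to M0 % M01
    const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01; // 2

    const int idx_M00          = idx_M0 / M01;          // 0
    const int idx_M01          = idx_M0 % M01;          // 1
    const int idx_N0_M01_local = idx_N0 + idx_M01 * N0; // 5

    const auto tile = std::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01,
                                      idx_N0_M01_local / M01_adapt);

    std::printf("{%d, %d}\n", std::get<0>(tile), std::get<1>(tile)); // prints {1, 2}
    return 0;
}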
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;
// 1-stage prefetch
template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
{
// TODO: improve applicability
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep>
static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
// sync for Load threads()
block_sync_lds();
// global read i + 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to i + 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// sync with math threads()
block_sync_lds();
// LDS write i+1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
}
}
};
template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmMathWave;
// 1-stage prefetch
template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename ABlockBuffer,
typename BBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
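The two structs above split a 1-stage-prefetch GEMM main loop between a load wave and a math wave. A host-side trace of that schedule (illustrative only, no CK dependencies; the tile count is hypothetical) makes the intended overlap explicit: while the math wave runs GEMM on tile i, the load wave fetches and stages tile i+1.

#include <cstdio>

int main()
{
    const int num_loop = 4; // hypothetical number of K blocks

    // prologue on the load wave: global read 0, LDS write 0
    std::printf("load wave : read tile 0, write tile 0\n");

    // main body: load wave fetches tile i+1 while math wave runs GEMM on tile i,
    // with LDS barriers separating the two phases
    for(int i = 0; i < num_loop - 1; ++i)
    {
        std::printf("iter %d    : load wave reads/writes tile %d | math wave GEMM tile %d\n",
                    i, i + 1, i);
    }

    // tail: math wave finishes GEMM on the last tile
    std::printf("tail      : math wave GEMM tile %d\n", num_loop - 1);
    return 0;
}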