Commit 644df335 authored by rocking

Merge branch 'develop' into gemm_layernorm_instance

parents d99640ab 7494c1c6
......@@ -13,10 +13,16 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank,
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation>
typename AccElementwiseOperation,
bool PropagateNan,
bool OutputIndex>
struct DeviceReduce : public BaseOperator
{
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
......@@ -27,8 +33,8 @@ struct DeviceReduce : public BaseOperator
const std::array<index_t, NumOutDim> outLengths,
const std::array<index_t, NumOutDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
......@@ -39,12 +45,26 @@ struct DeviceReduce : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank,
template <typename InDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation>
using DeviceReducePtr = std::unique_ptr<
DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;
typename AccElementwiseOperation,
bool PropagateNan,
bool OutputIndex>
using DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
// @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in]  reduceDims  The dimension(s) to which the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
// value as type AccDataType
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// value as type AccDataType
// @param[in]  alpha       Alpha scaling value, passed by value as a double
// @param[in]  beta        Beta scaling value, passed by value as a double
// @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor
// @param out_dev Typeless pointer in device memory storing the output tensor
......@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
double alpha,
double beta,
const void* in_dev,
void* out_dev,
InElementwiseOp in_elementwise_op,
......
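Note: the hunk above changes the softmax MakeArgumentPointer interface so that alpha and beta arrive as plain double values instead of type-erased host pointers. A minimal, self-contained sketch of the two patterns (the names below are illustrative, not CK's API): the old path dereferences a void* that must already point at the instance's AccDataType, while the new path takes a double and lets each instance narrow it internally.

```cpp
#include <iostream>

// Hypothetical accumulation type used by one device instance.
using AccDataType = float;

// Old pattern: the scaling factor arrives as a type-erased host pointer and the
// implementation must trust that it really points at an AccDataType.
AccDataType scale_from_void(const void* alpha)
{
    return *static_cast<const AccDataType*>(alpha); // undefined if the caller passed a double
}

// New pattern: the scaling factor is always a double and each implementation
// narrows it to its own AccDataType internally.
AccDataType scale_from_double(double alpha)
{
    return static_cast<AccDataType>(alpha);
}

int main()
{
    double alpha = 1.5;

    // With the void* interface the caller had to know AccDataType up front:
    AccDataType alpha_f = static_cast<AccDataType>(alpha);
    std::cout << scale_from_void(&alpha_f) << '\n';  // 1.5

    // With the double interface the caller just passes the value:
    std::cout << scale_from_double(alpha) << '\n';   // 1.5
    return 0;
}
```

The same pattern recurs in the reduce, normalization, and multiple-reduce hunks below, where static_cast<AccDataType>(alphas[i]) replaces *static_cast<const AccDataType*>(alphas[i]).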
......@@ -8,7 +8,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
......@@ -26,10 +26,10 @@ template <typename InDataTypeTuple,
index_t NPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n>
struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
OutDataTypeTuple,
ElementwiseOperation,
NumDim_m + NumDim_n>
{
static constexpr index_t NumDim = NumDim_m + NumDim_n;
......
......@@ -8,7 +8,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
......@@ -25,8 +25,8 @@ template <typename InDataTypeTuple,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwise
: public DeviceElementwiseBase<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
struct DeviceElementwiseImpl
: public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
......
......@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl
const std::vector<index_t> reduceDims,
XElementwiseOperation x_elementwise_op,
YElementwiseOperation y_elementwise_op,
AccDataType epsilon,
double epsilon,
const std::array<const void*, NumInput> in_dev_buffers,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
: epsilon_(epsilon),
p_gamma_(p_gamma),
: p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
x_elementwise_op_(x_elementwise_op),
y_elementwise_op_(y_elementwise_op)
{
epsilon_ = static_cast<AccDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
for(int i = 0; i < NumInput; i++)
......@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
AccDataType epsilon,
double epsilon,
const std::array<const void*, NumInput> in_dev_buffers,
const void* p_gamma,
const void* p_beta,
......
......@@ -431,9 +431,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
......@@ -471,6 +468,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
arg.block_2_etile_map_);
};
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
......
......@@ -486,7 +486,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......
......@@ -73,8 +73,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
static_for<0, NumReduction, 1>{}([&](auto I) {
using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
flag =
flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value;
flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value;
});
return flag;
......@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas,
const std::array<double, NumReduction>& alphas,
const std::array<double, NumReduction>& betas,
const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
......@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
for(size_t i = 0; i < NumReduction; i++)
{
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]);
alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
beta_values_(i) = static_cast<AccDataType>(betas[i]);
};
in_dev_ = static_cast<const InDataType*>(in_dev);
......@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas,
const std::array<const void*, NumReduction> betas,
const std::array<double, NumReduction> alphas,
const std::array<double, NumReduction> betas,
const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
......
......@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim>& outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
const std::array<int, NumReduceDim>& reduceDims,
const std::array<const void*, NumReduction>& alphas,
const std::array<const void*, NumReduction>& betas,
const std::array<double, NumReduction>& alphas,
const std::array<double, NumReduction>& betas,
const void* in_dev,
const std::array<void*, NumReduction>& out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
......@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
for(size_t i = 0; i < NumReduction; i++)
{
alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
beta_values_(i) = *static_cast<const AccDataType*>(betas[i]);
alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
beta_values_(i) = static_cast<AccDataType>(betas[i]);
};
in_dev_ = static_cast<const InDataType*>(in_dev);
......@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
const std::array<index_t, NumOutputDim> outLengths,
const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
const std::array<int, NumReduceDim> reduceDims,
const std::array<const void*, NumReduction> alphas,
const std::array<const void*, NumReduction> betas,
const std::array<double, NumReduction> alphas,
const std::array<double, NumReduction> betas,
const void* in_dev,
const std::array<void*, NumReduction> out_dev_buffers,
const InElementwiseOperationTuple in_elementwise_op_tuple,
......
......@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
AccElementwiseOperation acc_elementwise_op,
AccDataType epsilon,
double epsilon,
const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
: epsilon_(epsilon),
p_x_(p_x),
: p_x_(p_x),
p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
acc_elementwise_op_(acc_elementwise_op)
{
epsilon_ = static_cast<AccDataType>(epsilon);
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
......@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
AccDataType epsilon,
double epsilon,
const void* p_x,
const void* p_gamma,
const void* p_beta,
......
......@@ -40,8 +40,16 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceMultiBlock
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......@@ -67,8 +75,8 @@ struct DeviceReduceMultiBlock
static constexpr bool use_multiblock =
(OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value,
static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
OutDataType>::value,
"The OutDataType must support the specified OutMemoryDataOperation!");
static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
......@@ -209,8 +217,8 @@ struct DeviceReduceMultiBlock
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const InDataType* in_dev,
const IndexDataType* in_index_dev,
OutDataType* out_dev,
......@@ -494,8 +502,8 @@ struct DeviceReduceMultiBlock
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
......
......@@ -35,8 +35,17 @@ template <typename InDataType,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize>
struct DeviceReduceThreadWise
: public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
......@@ -156,8 +165,8 @@ struct DeviceReduceThreadWise
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const InDataType* in_dev,
OutDataType* out_dev,
IndexDataType* out_index_dev,
......@@ -332,8 +341,8 @@ struct DeviceReduceThreadWise
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
float alpha,
float beta,
double alpha,
double beta,
const void* in_dev,
const void* in_index_dev,
void* out_dev,
......
......@@ -156,19 +156,20 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
Argument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<index_t> reduceDims,
AccDataType alpha,
AccDataType beta,
double alpha,
double beta,
const InDataType* in_dev,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op)
: alpha_{alpha},
beta_{beta},
in_dev_{in_dev},
: in_dev_{in_dev},
out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op}
{
alpha_ = static_cast<AccDataType>(alpha);
beta_ = static_cast<AccDataType>(beta);
if(Rank != inLengths.size() || Rank != inStrides.size() ||
NumReduceDim != reduceDims.size())
{
......@@ -336,8 +337,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
static auto MakeArgument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const AccDataType alpha,
const AccDataType beta,
double alpha,
double beta,
const InDataType* in_dev,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
......@@ -375,8 +376,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const void* alpha,
const void* beta,
double alpha,
double beta,
const void* in_dev,
void* out_dev,
InElementwiseOp in_elementwise_op,
......@@ -385,8 +386,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
return std::make_unique<Argument>(inLengths,
inStrides,
reduceDims,
*static_cast<const AccDataType*>(alpha),
*static_cast<const AccDataType*>(beta),
alpha,
beta,
static_cast<const InDataType*>(in_dev),
static_cast<OutDataType*>(out_dev),
in_elementwise_op,
......
......@@ -12,7 +12,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp"
namespace ck {
namespace tensor_operation {
......@@ -24,16 +24,17 @@ template <typename EmbType,
typename BetaDataType,
typename AccDataType,
typename OutType,
typename EmbElementwiseOperation,
ck::index_t BlockSize,
ck::index_t DimClusterSize,
ck::index_t RowClusterSize,
ck::index_t DimPerBlock,
ck::index_t RowPerBlock,
ck::index_t DimThreadSize,
ck::index_t RowVectorSize>
struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
ck::index_t RowVectorSize,
ck::index_t NumEmbeddings>
struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
{
static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
{
return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));
......@@ -42,96 +43,79 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
struct Argument : public BaseArgument
{
Argument(OutType* p_out,
const EmbType* p_emb_a,
const EmbType* p_emb_b,
const EmbType* p_emb_c,
const IndexType* p_index_a,
const IndexType* p_index_b,
const IndexType* p_index_c,
const ck::Array<EmbType*, NumEmbeddings>& p_embs,
const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
const ck::index_t NumRows,
const ck::index_t EmbeddingDim,
const ck::index_t IndexLength,
const AccDataType epsilon)
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
: p_out_(p_out),
p_emb_a_(p_emb_a),
p_emb_b_(p_emb_b),
p_emb_c_(p_emb_c),
p_index_a_(p_index_a),
p_index_b_(p_index_b),
p_index_c_(p_index_c),
p_embs_(p_embs),
p_indexs_(p_indexs),
p_gamma_(p_gamma),
p_beta_(p_beta),
NumRows_(NumRows),
EmbeddingDim_(EmbeddingDim),
IndexLength_(IndexLength),
epsilon_(epsilon)
epsilon_(epsilon),
emb_elementwise_op_(emb_elementwise_op)
{
grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
}
OutType* p_out_;
const EmbType* p_emb_a_;
const EmbType* p_emb_b_;
const EmbType* p_emb_c_;
const IndexType* p_index_a_;
const IndexType* p_index_b_;
const IndexType* p_index_c_;
ck::Array<EmbType*, NumEmbeddings> p_embs_;
ck::Array<IndexType*, NumEmbeddings> p_indexs_;
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
ck::index_t NumRows_;
ck::index_t EmbeddingDim_;
ck::index_t IndexLength_;
AccDataType epsilon_;
EmbElementwiseOperation emb_elementwise_op_;
size_t grid_size_;
};
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(void* p_out,
const void* p_emb_a,
const void* p_emb_b,
const void* p_emb_c,
const void* p_index_a,
const void* p_index_b,
const void* p_index_c,
const void* p_gamma,
const void* p_beta,
ck::index_t NumRows,
ck::index_t EmbeddingDim,
ck::index_t IndexLength,
const AccDataType epsilon)
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_out,
const ck::Array<EmbType*, NumEmbeddings>& p_embs,
const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
const void* p_gamma,
const void* p_beta,
ck::index_t EmbeddingDim,
ck::index_t IndexLength,
const AccDataType epsilon,
const EmbElementwiseOperation emb_elementwise_op)
{
return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
reinterpret_cast<const EmbType*>(p_emb_a),
reinterpret_cast<const EmbType*>(p_emb_b),
reinterpret_cast<const EmbType*>(p_emb_c),
reinterpret_cast<const IndexType*>(p_index_a),
reinterpret_cast<const IndexType*>(p_index_b),
reinterpret_cast<const IndexType*>(p_index_c),
p_embs,
p_indexs,
reinterpret_cast<const GammaDataType*>(p_gamma),
reinterpret_cast<const BetaDataType*>(p_beta),
NumRows,
EmbeddingDim,
IndexLength,
epsilon);
epsilon,
emb_elementwise_op);
}
using GridwiseSparseEmbedding =
GridwiseSparseEmbedding3ForwardLayernorm<EmbType,
GridwiseSparseEmbeddingsForwardLayernorm<EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
decltype(MakeOutputDescriptor(1, 1)),
EmbElementwiseOperation,
BlockSize,
DimClusterSize,
RowClusterSize,
DimPerBlock,
RowPerBlock,
DimThreadSize,
RowVectorSize>;
RowVectorSize,
NumEmbeddings>;
struct Invoker : public BaseInvoker
{
......@@ -139,14 +123,16 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
{
auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
const auto kernel_main =
kernel_sparse_embedding3_forward_layernorm<GridwiseSparseEmbedding,
kernel_sparse_embeddings_forward_layernorm<GridwiseSparseEmbedding,
EmbType,
IndexType,
GammaDataType,
BetaDataType,
AccDataType,
OutType,
decltype(out_desc)>;
decltype(out_desc),
EmbElementwiseOperation,
NumEmbeddings>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
......@@ -154,16 +140,13 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
dim3(BlockSize),
0,
arg.p_out_,
arg.p_emb_a_,
arg.p_emb_b_,
arg.p_emb_c_,
arg.p_index_a_,
arg.p_index_b_,
arg.p_index_c_,
arg.p_embs_,
arg.p_indexs_,
arg.p_gamma_,
arg.p_beta_,
out_desc,
arg.epsilon_);
arg.epsilon_,
arg.emb_elementwise_op_);
return (avg_time);
}
......@@ -177,7 +160,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
static bool IsSupportedArgument(const Argument* p_arg)
{
return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0);
return (RowPerBlock == p_arg->EmbeddingDim_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
......@@ -195,7 +178,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
auto str = std::stringstream();
// clang-format off
str << "DeviceSparseEmbedding3ForwardLayernorm_"<< BlockSize << "_" <<
str << "DeviceSparseEmbeddingsForwardLayernorm_"<< BlockSize << "_" <<
DimClusterSize << "x" << RowClusterSize << "_" <<
DimPerBlock << "x" << RowPerBlock << "_" <<
DimThreadSize << "x" << RowVectorSize;
......
......@@ -172,6 +172,42 @@ struct AddAdd
}
};
// C = A * B
// E = (C + D0) x D1
struct AddMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
const half_t& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (c + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const half_t y = (type_convert<half_t>(c) + d0) * d1;
e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
const half_t& d1) const
{
const float y = (c + d0) * d1;
e = y;
}
};
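For reference, the AddMultiply specializations above all evaluate the same formula, e = (c + d0) * d1, and differ only in the half/float types involved. A small host-side analogue, with plain float standing in for ck::half_t so it compiles without CK headers (a sketch, not the device functor):

```cpp
#include <iostream>

// Host-side analogue of the AddMultiply elementwise op: e = (c + d0) * d1.
struct AddMultiplyRef
{
    template <typename E, typename C, typename D0, typename D1>
    void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
    {
        e = static_cast<E>((static_cast<float>(c) + static_cast<float>(d0)) *
                           static_cast<float>(d1));
    }
};

int main()
{
    float e  = 0.f;
    float c  = 2.f;   // C = A * B, produced by the GEMM
    float d0 = 1.f;   // additive tensor D0
    float d1 = 0.5f;  // multiplicative tensor D1
    AddMultiplyRef{}(e, c, d0, d1);
    std::cout << e << '\n'; // (2 + 1) * 0.5 = 1.5
    return 0;
}
```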
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
......@@ -278,6 +314,40 @@ struct Normalize
double epsilon_;
};
// used by BatchNorm inference
// y = gamma * (x-mean) / sqrt(epsilon+variance) + beta
// The data type of mean and variance is used as AccDataType
struct NormalizeInInfer
{
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2, typename T3, typename T4>
__host__ __device__ constexpr void operator()(T1& y,
const T1& x,
const T2& mean,
const T2& variance,
const T3& gamma,
const T4& beta) const
{
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
"Data type is not supported by this operation!");
using ck::type_convert;
using ck::math::sqrt;
T2 tmp_x, tmp_y;
tmp_x = type_convert<T2>(x);
tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
type_convert<T2>(gamma) +
type_convert<T2>(beta);
y = type_convert<T1>(tmp_y);
};
double epsilon_;
};
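A host-side reference of the NormalizeInInfer formula makes the casting scheme concrete; here float stands in for the x/gamma/beta tensor types and double for the mean/variance accumulation type, mirroring the static_assert above. This is a sketch, not CK's device implementation:

```cpp
#include <cmath>
#include <iostream>

// Reference for y = gamma * (x - mean) / sqrt(epsilon + variance) + beta,
// computed in the mean/variance type (double) and narrowed back to float.
float normalize_in_infer_ref(
    float x, double mean, double variance, float gamma, float beta, double epsilon = 1e-4)
{
    const double tmp = (static_cast<double>(x) - mean) / std::sqrt(variance + epsilon);
    return static_cast<float>(tmp * gamma + beta);
}

int main()
{
    // x = 3, mean = 1, variance = 4  ->  (3 - 1) / sqrt(4 + 1e-4) ~= 1.0
    std::cout << normalize_in_infer_ref(3.f, 1.0, 4.0, /*gamma=*/2.f, /*beta=*/0.5f) << '\n';
    // prints approximately 2.5
    return 0;
}
```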
template <typename Y, typename X>
struct UnaryTypeConvert;
......
......@@ -154,6 +154,50 @@ struct BlockToCTileMap_M00_N0_M01Adapt
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
......
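The worked example in the comment above can be checked with a few lines of host code. The values idx_N0, idx_M0, idx_M00, and M01_adapt are produced by earlier lines of CalculateBottomIndex that are outside this hunk, so they are taken directly from the example rather than re-derived:

```cpp
#include <iostream>

// Re-computation of the worked example: M0 = 5, N0 = 4, block_1d_id = 5, M01 = 2.
int main()
{
    const int N0 = 4, M01 = 2;
    const int idx_N0 = 1, idx_M0 = 1, idx_M00 = 0, M01_adapt = 2; // from the example comment

    const int idx_M01          = idx_M0 % M01;           // 1
    const int idx_N0_M01_local = idx_N0 + idx_M01 * N0;  // 1 + 1*4 = 5

    const int out_m = idx_N0_M01_local % M01_adapt + idx_M00 * M01; // 5 % 2 + 0*2 = 1
    const int out_n = idx_N0_M01_local / M01_adapt;                 // 5 / 2 = 2

    std::cout << "{" << out_m << ", " << out_n << "}\n"; // {1, 2}, matching the comment
    return 0;
}
```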
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;
// 1-stage prefetch
template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
{
// TODO: improve applicability
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep>
static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
// sync for Load threads()
block_sync_lds();
// global read i + 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to i + 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// sync with math threads()
block_sync_lds();
// LDS write i+1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
}
}
};
template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmMathWave;
// 1- stage prefetch
template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename ABlockBuffer,
typename BBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
BBlockBuffer& b_block_buf,
const BlockwiseGemm& block_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck
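The comments in the new load-wave/math-wave pipeline describe a 1-stage prefetch: the prologue stages tile 0 into LDS, each main-loop iteration overlaps the global read of tile i+1 with the GEMM on tile i, and the tail computes the last tile. A single-threaded schematic of that ordering (barriers omitted; purely illustrative, not the device code):

```cpp
#include <cstdio>

// Schematic of the 1-stage prefetch schedule: the load wave reads tile i+1 from
// global memory while the math wave consumes tile i already resident in LDS.
int main()
{
    const int num_loop = 4;

    std::printf("load : read 0, write 0\n");           // prologue: prefetch tile 0 into LDS
    for(int i = 0; i < num_loop - 1; ++i)
    {
        std::printf("load : read %d\n", i + 1);        // fetch the next tile ...
        std::printf("math : gemm %d\n", i);            // ... while computing on the current one
        std::printf("load : write %d\n", i + 1);       // stage the next tile into LDS
    }
    std::printf("math : gemm %d\n", num_loop - 1);     // tail: last tile, no further reads
    return 0;
}
```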