Commit 3e38e358 authored by rocking's avatar rocking
Browse files

Add accElementwiseOp

parent 6ed9ab3a
......@@ -23,6 +23,7 @@ using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t;
using YDataType = ck::half_t;
using AccDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2;
constexpr int NumReduceDim = 1;
......@@ -32,6 +33,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
BetaDataType,
AccDataType,
YDataType,
PassThrough,
Rank,
NumReduceDim,
256, // BlockSize
......@@ -136,7 +138,8 @@ int main()
x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer());
y_dev.GetDeviceBuffer(),
PassThrough{});
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{
......
......@@ -25,6 +25,7 @@ template <typename XDataType,
typename BetaDataType,
typename AccDataType,
typename YDataType,
typename AccElementwiseOperation,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
......@@ -56,8 +57,8 @@ struct DeviceLayernorm : public BaseOperator
Rank,
NumReduceDim,
reduce::Add,
PassThrough, // InElementwiseOperation
PassThrough, // AccElementwiseOperation
PassThrough, // InElementwiseOperation
AccElementwiseOperation, // AccElementwiseOperation
InMemoryDataOperationEnum::Set,
false, // PropagateNan
false, // OutputIndex
......@@ -109,6 +110,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
GridDesc_K,
BlockSize,
......@@ -128,6 +130,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
GridDesc_K,
BlockSize,
......@@ -149,6 +152,7 @@ struct DeviceLayernorm : public BaseOperator
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> reduceDims,
AccElementwiseOperation acc_elementwise_op,
AccDataType epsilon,
const XDataType* p_x,
const GammaDataType* p_gamma,
......@@ -165,7 +169,7 @@ struct DeviceLayernorm : public BaseOperator
nullptr,
p_y,
nullptr,
PassThrough{},
acc_elementwise_op,
PassThrough{}),
epsilon_(epsilon),
p_gamma_(p_gamma),
......@@ -211,6 +215,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
GridDesc_K>
: kernel_layernorm<GridwiseReduceLayernormGeneric,
......@@ -219,6 +224,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K,
GridDesc_K>;
......@@ -237,7 +243,8 @@ struct DeviceLayernorm : public BaseOperator
arg.in_dev_,
arg.p_gamma_,
arg.p_beta_,
arg.out_dev_);
arg.out_dev_,
arg.acc_elementwise_op_);
return (avg_time);
};
......@@ -296,13 +303,15 @@ struct DeviceLayernorm : public BaseOperator
const void* p_x,
const void* p_gamma,
const void* p_beta,
void* p_y)
void* p_y,
AccElementwiseOperation acc_elementwise_op)
{
return std::make_unique<Argument>(lengths,
xStrides,
gammaStrides,
betaStrides,
reduceDims,
acc_elementwise_op,
epsilon,
static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
......
......@@ -20,6 +20,7 @@ template <typename GridwiseReduction,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
......@@ -31,7 +32,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global)
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{
GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_k,
......@@ -42,7 +44,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
p_x_global,
p_gamma_global,
p_beta_global,
p_y_global);
p_y_global,
acc_elementwise_op);
};
template <typename XDataType,
......@@ -50,6 +53,7 @@ template <typename XDataType,
typename BetaDataType,
typename YDataType,
typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K,
typename GridDesc_K,
index_t BlockSize,
......@@ -105,8 +109,6 @@ struct GridwiseLayernorm_mk_to_mk
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -122,7 +124,8 @@ struct GridwiseLayernorm_mk_to_mk
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global)
YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{
if constexpr(SweepOnce)
{
......@@ -225,7 +228,7 @@ struct GridwiseLayernorm_mk_to_mk
YDataType,
decltype(thread_buffer_desc_m_k),
GridDesc_M_K,
PassThroughOp,
AccElementwiseOperation,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcVectorDim,
......@@ -237,7 +240,7 @@ struct GridwiseLayernorm_mk_to_mk
make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{});
acc_elementwise_op);
// Copy x from Cache
// one pass: fwd, second pass: bwd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment