Commit 3e38e358 authored by rocking
Browse files

Add accElementwiseOp

parent 6ed9ab3a
...@@ -23,6 +23,7 @@ using GammaDataType = ck::half_t; ...@@ -23,6 +23,7 @@ using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2; constexpr int Rank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
...@@ -32,6 +33,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType, ...@@ -32,6 +33,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceLayernorm<XDataType,
BetaDataType, BetaDataType,
AccDataType, AccDataType,
YDataType, YDataType,
PassThrough,
Rank, Rank,
NumReduceDim, NumReduceDim,
256, // BlockSize 256, // BlockSize
...@@ -136,7 +138,8 @@ int main() ...@@ -136,7 +138,8 @@ int main()
x_dev.GetDeviceBuffer(), x_dev.GetDeviceBuffer(),
gamma_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(),
beta_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(),
y_dev.GetDeviceBuffer()); y_dev.GetDeviceBuffer(),
PassThrough{});
if(!device_instance.IsSupportedArgument(argument_ptr.get())) if(!device_instance.IsSupportedArgument(argument_ptr.get()))
{ {
......
...@@ -25,6 +25,7 @@ template <typename XDataType, ...@@ -25,6 +25,7 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename AccDataType, typename AccDataType,
typename YDataType, typename YDataType,
typename AccElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim, index_t NumReduceDim,
index_t BlockSize, index_t BlockSize,
...@@ -57,7 +58,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -57,7 +58,7 @@ struct DeviceLayernorm : public BaseOperator
NumReduceDim, NumReduceDim,
reduce::Add, reduce::Add,
PassThrough, // InElementwiseOperation PassThrough, // InElementwiseOperation
PassThrough, // AccElementwiseOperation AccElementwiseOperation, // AccElementwiseOperation
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
false, // PropagateNan false, // PropagateNan
false, // OutputIndex false, // OutputIndex
...@@ -109,6 +110,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -109,6 +110,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_K, GridDesc_K,
BlockSize, BlockSize,
...@@ -128,6 +130,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -128,6 +130,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_K, GridDesc_K,
BlockSize, BlockSize,
...@@ -149,6 +152,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -149,6 +152,7 @@ struct DeviceLayernorm : public BaseOperator
const std::vector<index_t> gammaStrides, const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides, const std::vector<index_t> betaStrides,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccElementwiseOperation acc_elementwise_op,
AccDataType epsilon, AccDataType epsilon,
const XDataType* p_x, const XDataType* p_x,
const GammaDataType* p_gamma, const GammaDataType* p_gamma,
...@@ -165,7 +169,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -165,7 +169,7 @@ struct DeviceLayernorm : public BaseOperator
nullptr, nullptr,
p_y, p_y,
nullptr, nullptr,
PassThrough{}, acc_elementwise_op,
PassThrough{}), PassThrough{}),
epsilon_(epsilon), epsilon_(epsilon),
p_gamma_(p_gamma), p_gamma_(p_gamma),
...@@ -211,6 +215,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -211,6 +215,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_K> GridDesc_K>
: kernel_layernorm<GridwiseReduceLayernormGeneric, : kernel_layernorm<GridwiseReduceLayernormGeneric,
...@@ -219,6 +224,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -219,6 +224,7 @@ struct DeviceLayernorm : public BaseOperator
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
GridDesc_K>; GridDesc_K>;
...@@ -237,7 +243,8 @@ struct DeviceLayernorm : public BaseOperator ...@@ -237,7 +243,8 @@ struct DeviceLayernorm : public BaseOperator
arg.in_dev_, arg.in_dev_,
arg.p_gamma_, arg.p_gamma_,
arg.p_beta_, arg.p_beta_,
arg.out_dev_); arg.out_dev_,
arg.acc_elementwise_op_);
return (avg_time); return (avg_time);
}; };
...@@ -296,13 +303,15 @@ struct DeviceLayernorm : public BaseOperator ...@@ -296,13 +303,15 @@ struct DeviceLayernorm : public BaseOperator
const void* p_x, const void* p_x,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
void* p_y) void* p_y,
AccElementwiseOperation acc_elementwise_op)
{ {
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
xStrides, xStrides,
gammaStrides, gammaStrides,
betaStrides, betaStrides,
reduceDims, reduceDims,
acc_elementwise_op,
epsilon, epsilon,
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma), static_cast<const GammaDataType*>(p_gamma),
......
...@@ -20,6 +20,7 @@ template <typename GridwiseReduction, ...@@ -20,6 +20,7 @@ template <typename GridwiseReduction,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K, typename GridDesc_M_K,
typename GridDesc_K> typename GridDesc_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
...@@ -31,7 +32,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, ...@@ -31,7 +32,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global) YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{ {
GridwiseReduction::Run(x_grid_desc_m_k, GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_k, gamma_grid_desc_k,
...@@ -42,7 +44,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, ...@@ -42,7 +44,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
p_x_global, p_x_global,
p_gamma_global, p_gamma_global,
p_beta_global, p_beta_global,
p_y_global); p_y_global,
acc_elementwise_op);
}; };
template <typename XDataType, template <typename XDataType,
...@@ -50,6 +53,7 @@ template <typename XDataType, ...@@ -50,6 +53,7 @@ template <typename XDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename AccDataType,
typename AccElementwiseOperation,
typename GridDesc_M_K, typename GridDesc_M_K,
typename GridDesc_K, typename GridDesc_K,
index_t BlockSize, index_t BlockSize,
...@@ -105,8 +109,6 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -105,8 +109,6 @@ struct GridwiseLayernorm_mk_to_mk
false, // ignored false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>; detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
using PassThroughOp = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -122,7 +124,8 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -122,7 +124,8 @@ struct GridwiseLayernorm_mk_to_mk
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global) YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op)
{ {
if constexpr(SweepOnce) if constexpr(SweepOnce)
{ {
...@@ -225,7 +228,7 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -225,7 +228,7 @@ struct GridwiseLayernorm_mk_to_mk
YDataType, YDataType,
decltype(thread_buffer_desc_m_k), decltype(thread_buffer_desc_m_k),
GridDesc_M_K, GridDesc_M_K,
PassThroughOp, AccElementwiseOperation,
ThreadBufferLengths_M_K, ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
XSrcVectorDim, XSrcVectorDim,
...@@ -237,7 +240,7 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -237,7 +240,7 @@ struct GridwiseLayernorm_mk_to_mk
make_multi_index(block_global_id * M_BlockTileSize + make_multi_index(block_global_id * M_BlockTileSize +
thread_m_cluster_id * MThreadSliceSize, thread_m_cluster_id * MThreadSliceSize,
thread_k_cluster_id * KThreadSliceSize), thread_k_cluster_id * KThreadSliceSize),
PassThroughOp{}); acc_elementwise_op);
// Copy x from Cache // Copy x from Cache
// one pass: fwd, second pass: bwd // one pass: fwd, second pass: bwd
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment