Commit 59613285 authored by Qianfeng Zhang
Browse files

Add dy_elementwise_op

parent f4d67cf8
......@@ -253,6 +253,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin());
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using DeviceBatchNormBwdInstance =
ck::tensor_operation::device::DeviceBatchNormBwdImpl<InOutDataType,
InOutDataType,
......@@ -261,6 +263,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
AccDataType, // ScaleDataType
AccDataType, // BiasDataType
AccDataType, // MeanVarDataType
PassThroughOp,
Rank,
NumReduceDim,
UseMultiblockInK,
......@@ -295,6 +298,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr,
haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr,
epsilon,
PassThroughOp{},
dx_dev.GetDeviceBuffer(),
bnScaleDiff_dev.GetDeviceBuffer(),
bnBiasDiff_dev.GetDeviceBuffer());
......@@ -350,7 +354,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
AccDataType,
AccDataType,
AccDataType,
AccDataType>;
AccDataType,
PassThroughOp>;
auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
......@@ -370,6 +375,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
haveSavedMeanInvVar ? savedMean.mData.data() : nullptr,
haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr,
epsilon,
PassThroughOp{},
dx_ref.mData.data(),
bnScaleDiff_ref.mData.data(),
bnBiasDiff_ref.mData.data());
......
......@@ -13,7 +13,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim>
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
struct DeviceBatchNormBwd : public BaseOperator
{
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
......@@ -34,6 +34,7 @@ struct DeviceBatchNormBwd : public BaseOperator
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) = 0;
......@@ -41,8 +42,9 @@ struct DeviceBatchNormBwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>>;
template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
using DeviceBatchNormBwdPtr =
std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -29,6 +29,7 @@ template <typename XDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank,
index_t NumBatchNormReduceDim,
bool UseMultiblockInK,
......@@ -44,7 +45,8 @@ template <typename XDataType,
index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize,
index_t MeanVarSrcVectorSize>
struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
struct DeviceBatchNormBwdImpl
: public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
......@@ -199,6 +201,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const ScaleDataType* p_scale,
const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
double epsilon,
DxDataType* p_dx,
ScaleDataType* p_dscale,
......@@ -212,6 +215,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
p_scale_(p_scale),
p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx),
p_dscale_(p_dscale),
p_dbias_(p_dbias)
......@@ -293,6 +297,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const ScaleDataType* p_scale_;
const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
ScaleDataType* p_dscale_;
BiasDataType* p_dbias_;
......@@ -451,6 +456,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K,
......@@ -473,6 +479,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M,
......@@ -548,6 +555,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K,
......@@ -562,6 +570,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M,
......@@ -596,6 +605,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
: static_cast<const MeanVarDataType*>(arg.workspace_variance),
arg.haveSavedMeanInvVar_ ? nullptr
: static_cast<const int32_t*>(arg.workspace_count),
arg.dy_elementwise_op_,
arg.haveSavedMeanInvVar_
? nullptr
: static_cast<MeanVarDataType*>(arg.workspace_savedMean),
......@@ -635,6 +645,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.p_x_,
arg.p_dy_,
arg.p_scale_,
arg.dy_elementwise_op_,
arg.p_dx_,
arg.p_dscale_,
arg.p_dbias_);
......@@ -655,6 +666,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
ScaleBiasGridDesc_M,
MeanVarGridDesc_M,
......@@ -681,6 +693,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType,
BiasDataType,
MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K,
ScaleBiasGridDesc_M,
MeanVarGridDesc_M,
......@@ -707,6 +720,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.haveSavedMeanInvVar_,
arg.p_savedMean_,
arg.p_savedInvVar_,
arg.dy_elementwise_op_,
arg.p_dx_,
arg.p_dscale_,
arg.p_dbias_);
......@@ -800,6 +814,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) override
......@@ -818,6 +833,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
static_cast<const ScaleDataType*>(p_scale),
static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar),
dy_elementwise_op,
epsilon,
static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale),
......
......@@ -18,6 +18,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M,
......@@ -41,6 +42,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
......@@ -63,6 +65,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
p_x,
p_dy,
p_scale,
dy_elementwise_op,
p_dx,
p_dscale,
p_dbias);
......@@ -75,6 +78,7 @@ template <typename XDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M,
......@@ -163,6 +167,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
......@@ -498,6 +503,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
......
......@@ -19,6 +19,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
......@@ -39,6 +40,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
......@@ -61,6 +63,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
p_in_welford_mean,
p_in_welford_variance,
p_in_welford_count,
dy_elementwise_op,
p_out_welford_mean,
p_out_welford_inv_variance,
p_x,
......@@ -75,6 +78,7 @@ template <typename XDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
......@@ -165,6 +169,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x,
......@@ -480,6 +485,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
......
......@@ -23,6 +23,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
......@@ -44,6 +45,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
......@@ -64,6 +66,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
haveSavedMeanInvVar,
p_savedMean,
p_savedInvVar,
dy_elementwise_op,
p_dx,
p_dscale,
p_dbias);
......@@ -76,6 +79,7 @@ template <typename XDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K,
typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M,
......@@ -173,6 +177,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias)
......@@ -455,6 +460,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
......@@ -531,6 +539,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM];
......
......@@ -19,8 +19,10 @@ template <typename XDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatchNormBwd<4, 3>
typename MeanVarDataType,
typename DyElementwiseOp>
struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
: public device::DeviceBatchNormBwd<4, 3, DyElementwiseOp>
{
struct Argument : public device::BaseArgument
{
......@@ -39,6 +41,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
DxDataType* p_dx,
ScaleDataType* p_dscale,
BiasDataType* p_dbias)
......@@ -48,6 +51,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar),
epsilon_(epsilon),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx),
p_dscale_(p_dscale),
p_dbias_(p_dbias)
......@@ -79,6 +83,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const MeanVarDataType* p_savedInvVar_;
double epsilon_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_;
ScaleDataType* p_dscale_;
......@@ -165,6 +170,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
arg.dy_elementwise_op_(dy, dy);
dbias += dy;
dscale += norm_x * dy;
};
......@@ -194,6 +201,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
arg.dy_elementwise_op_(dy, dy);
AccDataType tmpVal = norm_x * dscale;
AccDataType dx = type_convert<AccDataType>(1.0f) / reduceSize * invVar *
......@@ -258,6 +267,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const void* p_savedMean,
const void* p_savedInvVar,
double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx,
void* p_dscale,
void* p_dbias) override
......@@ -277,6 +287,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon,
dy_elementwise_op,
static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment