Commit 59613285 authored by Qianfeng Zhang
Browse files

Add dy_elementwise_op

parent f4d67cf8
...@@ -253,6 +253,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification, ...@@ -253,6 +253,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
scaleBiasMeanVarStrides.end(), scaleBiasMeanVarStrides.end(),
i_scaleBiasMeanVarStrides.begin()); i_scaleBiasMeanVarStrides.begin());
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using DeviceBatchNormBwdInstance = using DeviceBatchNormBwdInstance =
ck::tensor_operation::device::DeviceBatchNormBwdImpl<InOutDataType, ck::tensor_operation::device::DeviceBatchNormBwdImpl<InOutDataType,
InOutDataType, InOutDataType,
...@@ -261,6 +263,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification, ...@@ -261,6 +263,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
AccDataType, // ScaleDataType AccDataType, // ScaleDataType
AccDataType, // BiasDataType AccDataType, // BiasDataType
AccDataType, // MeanVarDataType AccDataType, // MeanVarDataType
PassThroughOp,
Rank, Rank,
NumReduceDim, NumReduceDim,
UseMultiblockInK, UseMultiblockInK,
...@@ -295,6 +298,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification, ...@@ -295,6 +298,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr, haveSavedMeanInvVar ? savedMean_dev.GetDeviceBuffer() : nullptr,
haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr, haveSavedMeanInvVar ? savedInvVar_dev.GetDeviceBuffer() : nullptr,
epsilon, epsilon,
PassThroughOp{},
dx_dev.GetDeviceBuffer(), dx_dev.GetDeviceBuffer(),
bnScaleDiff_dev.GetDeviceBuffer(), bnScaleDiff_dev.GetDeviceBuffer(),
bnBiasDiff_dev.GetDeviceBuffer()); bnBiasDiff_dev.GetDeviceBuffer());
...@@ -350,7 +354,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification, ...@@ -350,7 +354,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
AccDataType, AccDataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
AccDataType>; AccDataType,
PassThroughOp>;
auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{}; auto batchNormBwd_ref = ReferenceBatchNormBwdInstance{};
...@@ -370,6 +375,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification, ...@@ -370,6 +375,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
haveSavedMeanInvVar ? savedMean.mData.data() : nullptr, haveSavedMeanInvVar ? savedMean.mData.data() : nullptr,
haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr, haveSavedMeanInvVar ? savedInvVar.mData.data() : nullptr,
epsilon, epsilon,
PassThroughOp{},
dx_ref.mData.data(), dx_ref.mData.data(),
bnScaleDiff_ref.mData.data(), bnScaleDiff_ref.mData.data(),
bnBiasDiff_ref.mData.data()); bnBiasDiff_ref.mData.data());
......
...@@ -13,7 +13,7 @@ namespace ck { ...@@ -13,7 +13,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <index_t Rank, index_t NumBatchNormReduceDim> template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
struct DeviceBatchNormBwd : public BaseOperator struct DeviceBatchNormBwd : public BaseOperator
{ {
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
...@@ -34,6 +34,7 @@ struct DeviceBatchNormBwd : public BaseOperator ...@@ -34,6 +34,7 @@ struct DeviceBatchNormBwd : public BaseOperator
const void* p_savedMean, const void* p_savedMean,
const void* p_savedInvVar, const void* p_savedInvVar,
double epsilon, double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx, void* p_dx,
void* p_dscale, void* p_dscale,
void* p_dbias) = 0; void* p_dbias) = 0;
...@@ -41,8 +42,9 @@ struct DeviceBatchNormBwd : public BaseOperator ...@@ -41,8 +42,9 @@ struct DeviceBatchNormBwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <index_t Rank, index_t NumBatchNormReduceDim> template <index_t Rank, index_t NumBatchNormReduceDim, typename DyElementwiseOp>
using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>>; using DeviceBatchNormBwdPtr =
std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -29,6 +29,7 @@ template <typename XDataType, ...@@ -29,6 +29,7 @@ template <typename XDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp,
index_t Rank, index_t Rank,
index_t NumBatchNormReduceDim, index_t NumBatchNormReduceDim,
bool UseMultiblockInK, bool UseMultiblockInK,
...@@ -44,7 +45,8 @@ template <typename XDataType, ...@@ -44,7 +45,8 @@ template <typename XDataType,
index_t ScaleSrcDstVectorSize, index_t ScaleSrcDstVectorSize,
index_t BiasDstVectorSize, index_t BiasDstVectorSize,
index_t MeanVarSrcVectorSize> index_t MeanVarSrcVectorSize>
struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim> struct DeviceBatchNormBwdImpl
: public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim, DyElementwiseOp>
{ {
static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
...@@ -199,6 +201,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -199,6 +201,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const ScaleDataType* p_scale, const ScaleDataType* p_scale,
const MeanVarDataType* p_savedMean, const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar, const MeanVarDataType* p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
double epsilon, double epsilon,
DxDataType* p_dx, DxDataType* p_dx,
ScaleDataType* p_dscale, ScaleDataType* p_dscale,
...@@ -212,6 +215,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -212,6 +215,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
p_scale_(p_scale), p_scale_(p_scale),
p_savedMean_(p_savedMean), p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar), p_savedInvVar_(p_savedInvVar),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx), p_dx_(p_dx),
p_dscale_(p_dscale), p_dscale_(p_dscale),
p_dbias_(p_dbias) p_dbias_(p_dbias)
...@@ -293,6 +297,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -293,6 +297,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const ScaleDataType* p_scale_; const ScaleDataType* p_scale_;
const MeanVarDataType* p_savedMean_; const MeanVarDataType* p_savedMean_;
const MeanVarDataType* p_savedInvVar_; const MeanVarDataType* p_savedInvVar_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_; DxDataType* p_dx_;
ScaleDataType* p_dscale_; ScaleDataType* p_dscale_;
BiasDataType* p_dbias_; BiasDataType* p_dbias_;
...@@ -451,6 +456,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -451,6 +456,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType, ScaleDataType,
BiasDataType, BiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
MeanVarGridDesc_M, MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K, MeanVarCountGridDesc_M_K,
...@@ -473,6 +479,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -473,6 +479,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType, ScaleDataType,
BiasDataType, BiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
DscaleDbiasGridDesc_M_K, DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M, MeanVarGridDesc_M,
...@@ -548,6 +555,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -548,6 +555,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType, ScaleDataType,
BiasDataType, BiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
MeanVarGridDesc_M, MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K, MeanVarCountGridDesc_M_K,
...@@ -562,6 +570,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -562,6 +570,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType, ScaleDataType,
BiasDataType, BiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
DscaleDbiasGridDesc_M_K, DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M, MeanVarGridDesc_M,
...@@ -596,6 +605,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -596,6 +605,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
: static_cast<const MeanVarDataType*>(arg.workspace_variance), : static_cast<const MeanVarDataType*>(arg.workspace_variance),
arg.haveSavedMeanInvVar_ ? nullptr arg.haveSavedMeanInvVar_ ? nullptr
: static_cast<const int32_t*>(arg.workspace_count), : static_cast<const int32_t*>(arg.workspace_count),
arg.dy_elementwise_op_,
arg.haveSavedMeanInvVar_ arg.haveSavedMeanInvVar_
? nullptr ? nullptr
: static_cast<MeanVarDataType*>(arg.workspace_savedMean), : static_cast<MeanVarDataType*>(arg.workspace_savedMean),
...@@ -635,6 +645,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -635,6 +645,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.p_x_, arg.p_x_,
arg.p_dy_, arg.p_dy_,
arg.p_scale_, arg.p_scale_,
arg.dy_elementwise_op_,
arg.p_dx_, arg.p_dx_,
arg.p_dscale_, arg.p_dscale_,
arg.p_dbias_); arg.p_dbias_);
...@@ -655,6 +666,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -655,6 +666,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType, ScaleDataType,
BiasDataType, BiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
ScaleBiasGridDesc_M, ScaleBiasGridDesc_M,
MeanVarGridDesc_M, MeanVarGridDesc_M,
...@@ -681,6 +693,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -681,6 +693,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType, ScaleDataType,
BiasDataType, BiasDataType,
MeanVarDataType, MeanVarDataType,
DyElementwiseOp,
XYGridDesc_M_K, XYGridDesc_M_K,
ScaleBiasGridDesc_M, ScaleBiasGridDesc_M,
MeanVarGridDesc_M, MeanVarGridDesc_M,
...@@ -707,6 +720,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -707,6 +720,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.haveSavedMeanInvVar_, arg.haveSavedMeanInvVar_,
arg.p_savedMean_, arg.p_savedMean_,
arg.p_savedInvVar_, arg.p_savedInvVar_,
arg.dy_elementwise_op_,
arg.p_dx_, arg.p_dx_,
arg.p_dscale_, arg.p_dscale_,
arg.p_dbias_); arg.p_dbias_);
...@@ -800,6 +814,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -800,6 +814,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const void* p_savedMean, const void* p_savedMean,
const void* p_savedInvVar, const void* p_savedInvVar,
double epsilon, double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx, void* p_dx,
void* p_dscale, void* p_dscale,
void* p_dbias) override void* p_dbias) override
...@@ -818,6 +833,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -818,6 +833,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
static_cast<const ScaleDataType*>(p_scale), static_cast<const ScaleDataType*>(p_scale),
static_cast<const MeanVarDataType*>(p_savedMean), static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar), static_cast<const MeanVarDataType*>(p_savedInvVar),
dy_elementwise_op,
epsilon, epsilon,
static_cast<DxDataType*>(p_dx), static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale), static_cast<ScaleDataType*>(p_dscale),
......
...@@ -18,6 +18,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_, ...@@ -18,6 +18,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K, typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
...@@ -41,6 +42,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final( ...@@ -41,6 +42,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy, const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale, const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) BiasDataType* const __restrict__ p_dbias)
...@@ -63,6 +65,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final( ...@@ -63,6 +65,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
p_x, p_x,
p_dy, p_dy,
p_scale, p_scale,
dy_elementwise_op,
p_dx, p_dx,
p_dscale, p_dscale,
p_dbias); p_dbias);
...@@ -75,6 +78,7 @@ template <typename XDataType, ...@@ -75,6 +78,7 @@ template <typename XDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K, typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
...@@ -163,6 +167,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -163,6 +167,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
const DyDataType* const __restrict__ p_dy, const DyDataType* const __restrict__ p_dy,
const ScaleDataType* const __restrict__ p_scale, const ScaleDataType* const __restrict__ p_scale,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) BiasDataType* const __restrict__ p_dbias)
...@@ -498,6 +503,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -498,6 +503,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
constexpr auto offset = constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM]; inv_var_thread_buf[iM];
......
...@@ -19,6 +19,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_, ...@@ -19,6 +19,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K, typename MeanVarCountGridDesc_M_K,
...@@ -39,6 +40,7 @@ __global__ void kernel_welford_second_half_reduce_first_half( ...@@ -39,6 +40,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
const MeanVarDataType* const __restrict__ p_in_welford_mean, const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance, const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count, const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean, MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance, MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
...@@ -61,6 +63,7 @@ __global__ void kernel_welford_second_half_reduce_first_half( ...@@ -61,6 +63,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
p_in_welford_mean, p_in_welford_mean,
p_in_welford_variance, p_in_welford_variance,
p_in_welford_count, p_in_welford_count,
dy_elementwise_op,
p_out_welford_mean, p_out_welford_mean,
p_out_welford_inv_variance, p_out_welford_inv_variance,
p_x, p_x,
...@@ -75,6 +78,7 @@ template <typename XDataType, ...@@ -75,6 +78,7 @@ template <typename XDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K, typename MeanVarCountGridDesc_M_K,
...@@ -165,6 +169,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -165,6 +169,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
const MeanVarDataType* const __restrict__ p_in_welford_mean, const MeanVarDataType* const __restrict__ p_in_welford_mean,
const MeanVarDataType* const __restrict__ p_in_welford_variance, const MeanVarDataType* const __restrict__ p_in_welford_variance,
const int32_t* const __restrict__ p_in_welford_count, const int32_t* const __restrict__ p_in_welford_count,
const DyElementwiseOp dy_elementwise_op,
MeanVarDataType* const __restrict__ p_out_welford_mean, MeanVarDataType* const __restrict__ p_out_welford_mean,
MeanVarDataType* const __restrict__ p_out_welford_inv_variance, MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
const XDataType* const __restrict__ p_x, const XDataType* const __restrict__ p_x,
...@@ -480,6 +485,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -480,6 +485,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
constexpr auto offset = constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM]; inv_var_thread_buf[iM];
......
...@@ -23,6 +23,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_, ...@@ -23,6 +23,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename ScaleBiasGridDesc_M, typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
...@@ -44,6 +45,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford( ...@@ -44,6 +45,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
bool haveSavedMeanInvVar, bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean, const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar, const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) BiasDataType* const __restrict__ p_dbias)
...@@ -64,6 +66,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford( ...@@ -64,6 +66,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
haveSavedMeanInvVar, haveSavedMeanInvVar,
p_savedMean, p_savedMean,
p_savedInvVar, p_savedInvVar,
dy_elementwise_op,
p_dx, p_dx,
p_dscale, p_dscale,
p_dbias); p_dbias);
...@@ -76,6 +79,7 @@ template <typename XDataType, ...@@ -76,6 +79,7 @@ template <typename XDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename DyElementwiseOp,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename ScaleBiasGridDesc_M, typename ScaleBiasGridDesc_M,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
...@@ -173,6 +177,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -173,6 +177,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
bool haveSavedMeanInvVar, bool haveSavedMeanInvVar,
const MeanVarDataType* const __restrict__ p_savedMean, const MeanVarDataType* const __restrict__ p_savedMean,
const MeanVarDataType* const __restrict__ p_savedInvVar, const MeanVarDataType* const __restrict__ p_savedInvVar,
const DyElementwiseOp dy_elementwise_op,
DxDataType* const __restrict__ p_dx, DxDataType* const __restrict__ p_dx,
ScaleDataType* const __restrict__ p_dscale, ScaleDataType* const __restrict__ p_dscale,
BiasDataType* const __restrict__ p_dbias) BiasDataType* const __restrict__ p_dbias)
...@@ -455,6 +460,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -455,6 +460,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
constexpr auto offset = constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM]; inv_var_thread_buf[iM];
...@@ -531,6 +539,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford ...@@ -531,6 +539,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
constexpr auto offset = constexpr auto offset =
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
dy_thread_buf[Number<offset>{}]);
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
inv_var_thread_buf[iM]; inv_var_thread_buf[iM];
......
...@@ -19,8 +19,10 @@ template <typename XDataType, ...@@ -19,8 +19,10 @@ template <typename XDataType,
typename AccDataType, typename AccDataType,
typename ScaleDataType, typename ScaleDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType> typename MeanVarDataType,
struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatchNormBwd<4, 3> typename DyElementwiseOp>
struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
: public device::DeviceBatchNormBwd<4, 3, DyElementwiseOp>
{ {
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
{ {
...@@ -39,6 +41,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -39,6 +41,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const MeanVarDataType* p_savedMean, const MeanVarDataType* p_savedMean,
const MeanVarDataType* p_savedInvVar, const MeanVarDataType* p_savedInvVar,
double epsilon, double epsilon,
const DyElementwiseOp dy_elementwise_op,
DxDataType* p_dx, DxDataType* p_dx,
ScaleDataType* p_dscale, ScaleDataType* p_dscale,
BiasDataType* p_dbias) BiasDataType* p_dbias)
...@@ -48,6 +51,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -48,6 +51,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
p_savedMean_(p_savedMean), p_savedMean_(p_savedMean),
p_savedInvVar_(p_savedInvVar), p_savedInvVar_(p_savedInvVar),
epsilon_(epsilon), epsilon_(epsilon),
dy_elementwise_op_(dy_elementwise_op),
p_dx_(p_dx), p_dx_(p_dx),
p_dscale_(p_dscale), p_dscale_(p_dscale),
p_dbias_(p_dbias) p_dbias_(p_dbias)
...@@ -79,6 +83,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -79,6 +83,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const MeanVarDataType* p_savedInvVar_; const MeanVarDataType* p_savedInvVar_;
double epsilon_; double epsilon_;
const DyElementwiseOp dy_elementwise_op_;
DxDataType* p_dx_; DxDataType* p_dx_;
ScaleDataType* p_dscale_; ScaleDataType* p_dscale_;
...@@ -165,6 +170,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -165,6 +170,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType norm_x = (x - mean) * invVar; AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]); AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
arg.dy_elementwise_op_(dy, dy);
dbias += dy; dbias += dy;
dscale += norm_x * dy; dscale += norm_x * dy;
}; };
...@@ -194,6 +201,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -194,6 +201,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]); AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]); AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
arg.dy_elementwise_op_(dy, dy);
AccDataType tmpVal = norm_x * dscale; AccDataType tmpVal = norm_x * dscale;
AccDataType dx = type_convert<AccDataType>(1.0f) / reduceSize * invVar * AccDataType dx = type_convert<AccDataType>(1.0f) / reduceSize * invVar *
...@@ -258,6 +267,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -258,6 +267,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const void* p_savedMean, const void* p_savedMean,
const void* p_savedInvVar, const void* p_savedInvVar,
double epsilon, double epsilon,
const DyElementwiseOp dy_elementwise_op,
void* p_dx, void* p_dx,
void* p_dscale, void* p_dscale,
void* p_dbias) override void* p_dbias) override
...@@ -277,6 +287,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -277,6 +287,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
static_cast<const MeanVarDataType*>(p_savedMean), static_cast<const MeanVarDataType*>(p_savedMean),
static_cast<const MeanVarDataType*>(p_savedInvVar), static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon, epsilon,
dy_elementwise_op,
static_cast<DxDataType*>(p_dx), static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_dscale), static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias)); static_cast<BiasDataType*>(p_dbias));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment