"examples/vscode:/vscode.git/clone" did not exist on "0240d4191adfe22e66e43826261795c0f1217651"
Commit bb220a7a authored by Qianfeng Zhang
Browse files

Renaming again

parent 213187f6
...@@ -427,7 +427,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -427,7 +427,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor( DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
arg.invariant_length, arg.blkGroupSize); arg.invariant_length, arg.blkGroupSize);
const auto scale_bias_diff_grid_desc_m_g = const auto dscale_dbias_grid_desc_m_g =
DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor( DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
arg.invariant_length, arg.blkGroupSize); arg.invariant_length, arg.blkGroupSize);
...@@ -435,14 +435,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -435,14 +435,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor( DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
arg.invariant_length, arg.blkGroupSize); arg.invariant_length, arg.blkGroupSize);
const auto scale_bias_diff_grid_desc_m_k = const auto dscale_dbias_grid_desc_m_k =
DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor( DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
arg.invariant_length, arg.blkGroupSize); arg.invariant_length, arg.blkGroupSize);
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using ScaleBiasDiffGridDesc_M_G = decltype(scale_bias_diff_grid_desc_m_g); using DscaleDbiasGridDesc_M_G = decltype(dscale_dbias_grid_desc_m_g);
using ScaleBiasDiffGridDesc_M_K = decltype(scale_bias_diff_grid_desc_m_k); using DscaleDbiasGridDesc_M_K = decltype(dscale_dbias_grid_desc_m_k);
using GridwiseWelfordSecondHalfReduceFirstHalf_ = using GridwiseWelfordSecondHalfReduceFirstHalf_ =
GridwiseWelfordSecondHalfReduceFirstHalf<XDataType, GridwiseWelfordSecondHalfReduceFirstHalf<XDataType,
...@@ -454,7 +454,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -454,7 +454,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
XYGridDesc_M_K, XYGridDesc_M_K,
MeanVarGridDesc_M, MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K, MeanVarCountGridDesc_M_K,
ScaleBiasDiffGridDesc_M_G, DscaleDbiasGridDesc_M_G,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
...@@ -474,7 +474,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -474,7 +474,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
BiasDataType, BiasDataType,
MeanVarDataType, MeanVarDataType,
XYGridDesc_M_K, XYGridDesc_M_K,
ScaleBiasDiffGridDesc_M_K, DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M, MeanVarGridDesc_M,
ScaleBiasGridDesc_M, ScaleBiasGridDesc_M,
BlockSize, BlockSize,
...@@ -551,7 +551,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -551,7 +551,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
XYGridDesc_M_K, XYGridDesc_M_K,
MeanVarGridDesc_M, MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K, MeanVarCountGridDesc_M_K,
ScaleBiasDiffGridDesc_M_G>; DscaleDbiasGridDesc_M_G>;
const auto kern_reduce_second_half_batchnorm_backward_final = const auto kern_reduce_second_half_batchnorm_backward_final =
kernel_reduce_second_half_batchnorm_backward_final< kernel_reduce_second_half_batchnorm_backward_final<
...@@ -563,11 +563,11 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -563,11 +563,11 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
BiasDataType, BiasDataType,
MeanVarDataType, MeanVarDataType,
XYGridDesc_M_K, XYGridDesc_M_K,
ScaleBiasDiffGridDesc_M_K, DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M, MeanVarGridDesc_M,
ScaleBiasGridDesc_M>; ScaleBiasGridDesc_M>;
index_t numScaleBiasDiffBlockTileIteration = index_t numDscaleDbiasBlockTileIteration =
(arg.blkGroupSize + KThreadClusterSize - 1) / KThreadClusterSize; (arg.blkGroupSize + KThreadClusterSize - 1) / KThreadClusterSize;
avg_time += launch_and_time_kernel( avg_time += launch_and_time_kernel(
...@@ -580,10 +580,10 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -580,10 +580,10 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.dy_grid_desc_m_k, arg.dy_grid_desc_m_k,
arg.mean_var_grid_desc_m, arg.mean_var_grid_desc_m,
mean_var_count_grid_desc_m_k, mean_var_count_grid_desc_m_k,
scale_bias_diff_grid_desc_m_g, dscale_dbias_grid_desc_m_g,
arg.blkGroupSize, arg.blkGroupSize,
arg.numBlockTileIteration, arg.numBlockTileIteration,
numScaleBiasDiffBlockTileIteration, numDscaleDbiasBlockTileIteration,
arg.epsilon_, arg.epsilon_,
arg.haveSavedMeanInvVar_, arg.haveSavedMeanInvVar_,
arg.haveSavedMeanInvVar_ ? arg.p_savedMean_ : nullptr, arg.haveSavedMeanInvVar_ ? arg.p_savedMean_ : nullptr,
...@@ -616,14 +616,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu ...@@ -616,14 +616,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.x_grid_desc_m_k, arg.x_grid_desc_m_k,
arg.dy_grid_desc_m_k, arg.dy_grid_desc_m_k,
arg.dx_grid_desc_m_k, arg.dx_grid_desc_m_k,
scale_bias_diff_grid_desc_m_k, dscale_dbias_grid_desc_m_k,
arg.mean_var_grid_desc_m, arg.mean_var_grid_desc_m,
arg.scale_grid_desc_m, arg.scale_grid_desc_m,
arg.bias_grid_desc_m, arg.bias_grid_desc_m,
arg.blkGroupSize, arg.blkGroupSize,
arg.reduce_length, arg.reduce_length,
arg.numBlockTileIteration, arg.numBlockTileIteration,
numScaleBiasDiffBlockTileIteration, numDscaleDbiasBlockTileIteration,
static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale), static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<const BiasDataType*>(arg.workspace_reduce_dbias), static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
arg.haveSavedMeanInvVar_ arg.haveSavedMeanInvVar_
......
...@@ -19,21 +19,21 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_, ...@@ -19,21 +19,21 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename ScaleBiasDiffGridDesc_M_K, typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
typename ScaleBiasGridDesc_M> typename ScaleBiasGridDesc_M>
__global__ void kernel_reduce_second_half_batchnorm_backward_final( __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_K dscale_dbias_grid_desc_m_k, const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m,
const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m,
index_t blkgroup_size, index_t blkgroup_size,
long_index_t reduce_size, long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_scale_bias_diff_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale, const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias, const BiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean, const MeanVarDataType* const __restrict__ p_mean,
...@@ -55,7 +55,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final( ...@@ -55,7 +55,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
blkgroup_size, blkgroup_size,
reduce_size, reduce_size,
num_xy_k_block_tile_iteration, num_xy_k_block_tile_iteration,
num_scale_bias_diff_k_block_tile_iteration, num_dscale_dbias_k_block_tile_iteration,
p_reduce_dscale, p_reduce_dscale,
p_reduce_dbias, p_reduce_dbias,
p_mean, p_mean,
...@@ -76,7 +76,7 @@ template <typename XDataType, ...@@ -76,7 +76,7 @@ template <typename XDataType,
typename BiasDataType, typename BiasDataType,
typename MeanVarDataType, typename MeanVarDataType,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename ScaleBiasDiffGridDesc_M_K, typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
typename ScaleBiasGridDesc_M, typename ScaleBiasGridDesc_M,
index_t BlockSize, index_t BlockSize,
...@@ -148,14 +148,14 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -148,14 +148,14 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k, const XYGridDesc_M_K& dy_grid_desc_m_k,
const XYGridDesc_M_K& dx_grid_desc_m_k, const XYGridDesc_M_K& dx_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_K& dscale_dbias_grid_desc_m_k, const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m, const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m, const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m, const ScaleBiasGridDesc_M& bias_grid_desc_m,
index_t blkgroup_size, index_t blkgroup_size,
long_index_t reduce_size, long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_scale_bias_diff_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale, const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias, const BiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean, const MeanVarDataType* const __restrict__ p_mean,
...@@ -220,7 +220,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -220,7 +220,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
auto threadwise_dscale_load_m_k = auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType, ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType, AccDataType,
ScaleBiasDiffGridDesc_M_K, DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
Sequence<0, 1>, Sequence<0, 1>,
...@@ -236,7 +236,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -236,7 +236,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
auto threadwise_dbias_load_m_k = auto threadwise_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<BiasDataType, ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType, AccDataType,
ScaleBiasDiffGridDesc_M_K, DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
Sequence<0, 1>, Sequence<0, 1>,
...@@ -305,7 +305,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal ...@@ -305,7 +305,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
dbias_thread_buf(I) = type_convert<AccDataType>(0.0f); dbias_thread_buf(I) = type_convert<AccDataType>(0.0f);
}); });
for(index_t reducedTiles = 0; reducedTiles < num_scale_bias_diff_k_block_tile_iteration; for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
++reducedTiles) ++reducedTiles)
{ {
threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k, threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
......
...@@ -22,13 +22,13 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_, ...@@ -22,13 +22,13 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K, typename MeanVarCountGridDesc_M_K,
typename ScaleBiasDiffGridDesc_M_G> typename DscaleDbiasGridDesc_M_G>
__global__ void kernel_welford_second_half_reduce_first_half( __global__ void kernel_welford_second_half_reduce_first_half(
const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_G dscale_dbias_grid_desc_m_g, const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g,
index_t blkgroup_size, index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration,
...@@ -78,7 +78,7 @@ template <typename XDataType, ...@@ -78,7 +78,7 @@ template <typename XDataType,
typename XYGridDesc_M_K, typename XYGridDesc_M_K,
typename MeanVarGridDesc_M, typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K, typename MeanVarCountGridDesc_M_K,
typename ScaleBiasDiffGridDesc_M_G, typename DscaleDbiasGridDesc_M_G,
index_t BlockSize, index_t BlockSize,
index_t MThreadClusterSize, index_t MThreadClusterSize,
index_t KThreadClusterSize, index_t KThreadClusterSize,
...@@ -154,7 +154,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -154,7 +154,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
const XYGridDesc_M_K& dy_grid_desc_m_k, const XYGridDesc_M_K& dy_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m, const MeanVarGridDesc_M& mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_G& dscale_dbias_grid_desc_m_g, const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g,
index_t blkgroup_size, index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration, index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration,
...@@ -504,7 +504,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -504,7 +504,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType, ScaleDataType,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
ScaleBiasDiffGridDesc_M_G, DscaleDbiasGridDesc_M_G,
PassThroughOp, PassThroughOp,
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
Sequence<0, 1>, Sequence<0, 1>,
...@@ -523,7 +523,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf ...@@ -523,7 +523,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
ThreadwiseTensorSliceTransfer_v1r3<AccDataType, ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType, BiasDataType,
decltype(thread_buffer_desc_m_1), decltype(thread_buffer_desc_m_1),
ScaleBiasDiffGridDesc_M_G, DscaleDbiasGridDesc_M_G,
PassThroughOp, PassThroughOp,
ThreadBufferLengths_M_1, ThreadBufferLengths_M_1,
Sequence<0, 1>, Sequence<0, 1>,
......
...@@ -40,8 +40,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -40,8 +40,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const MeanVarDataType* p_savedInvVar, const MeanVarDataType* p_savedInvVar,
double epsilon, double epsilon,
DxDataType* p_dx, DxDataType* p_dx,
ScaleDataType* p_scaleDiff, ScaleDataType* p_dscale,
BiasDataType* p_biasDiff) BiasDataType* p_dbias)
: p_x_(p_x), : p_x_(p_x),
p_dy_(p_dy), p_dy_(p_dy),
p_scale_(p_scale), p_scale_(p_scale),
...@@ -49,8 +49,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -49,8 +49,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
p_savedInvVar_(p_savedInvVar), p_savedInvVar_(p_savedInvVar),
epsilon_(epsilon), epsilon_(epsilon),
p_dx_(p_dx), p_dx_(p_dx),
p_scaleDiff_(p_scaleDiff), p_dscale_(p_dscale),
p_biasDiff_(p_biasDiff) p_dbias_(p_dbias)
{ {
ignore = xStrides; ignore = xStrides;
ignore = dyStrides; ignore = dyStrides;
...@@ -81,8 +81,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -81,8 +81,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
double epsilon_; double epsilon_;
DxDataType* p_dx_; DxDataType* p_dx_;
ScaleDataType* p_scaleDiff_; ScaleDataType* p_dscale_;
BiasDataType* p_biasDiff_; BiasDataType* p_dbias_;
bool haveSavedMeanInvVar_; bool haveSavedMeanInvVar_;
...@@ -113,7 +113,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -113,7 +113,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
meansquare = type_convert<AccDataType>(0.0f); meansquare = type_convert<AccDataType>(0.0f);
mean = type_convert<AccDataType>(0.0f); mean = type_convert<AccDataType>(0.0f);
// compute mean, meansquare, variance, invVariance // compute mean, meansquare, variance, inv-variance
for(index_t iN = 0; iN < arg.n_; iN++) for(index_t iN = 0; iN < arg.n_; iN++)
{ {
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_; index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
...@@ -142,13 +142,12 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -142,13 +142,12 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance); std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
}; };
AccDataType bnBiasDiff = type_convert<AccDataType>(0.0f); // Sum on NHW of dy AccDataType dbias = type_convert<AccDataType>(0.0f); // Sum on NHW of dy
AccDataType bnScaleDiff = AccDataType dscale = type_convert<AccDataType>(0.0f); // Sum on NHW of dy * norm_x
type_convert<AccDataType>(0.0f); // Sum on NHW of dy * norm_x
// 1) calculate dy * (x - mean) * invVariance // 1) calculate dy * (x - mean) * inv-variance
// 2) calculate Sum on NHWC of dy // 2) calculate sum(dy) on NHW dimensions
// 3) calculate Sum on NHWC of dy * norm_x // 3) calculate sum(dy * norm_x) on NHW dimensions
for(index_t iN = 0; iN < arg.n_; iN++) for(index_t iN = 0; iN < arg.n_; iN++)
{ {
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_; index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
...@@ -166,17 +165,17 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -166,17 +165,17 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType norm_x = (x - mean) * invVar; AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]); AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
bnBiasDiff += dy; dbias += dy;
bnScaleDiff += norm_x * dy; dscale += norm_x * dy;
}; };
} }
}; };
arg.p_scaleDiff_[offset_C] = type_convert<ScaleDataType>(bnScaleDiff); arg.p_dscale_[offset_C] = type_convert<ScaleDataType>(dscale);
arg.p_biasDiff_[offset_C] = type_convert<BiasDataType>(bnBiasDiff); arg.p_dbias_[offset_C] = type_convert<BiasDataType>(dbias);
// 1) calculate tmp = scaleDiff * (x - mean) * invVariance // 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/nhw * invVariance * scale * (nhw * dy - biasDiff - tmp) // 2) calculate dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - tmp)
for(index_t iN = 0; iN < arg.n_; iN++) for(index_t iN = 0; iN < arg.n_; iN++)
{ {
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_; index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
...@@ -195,10 +194,10 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -195,10 +194,10 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]); AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]); AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
AccDataType tmpVal = norm_x * bnScaleDiff; AccDataType tmpVal = norm_x * dscale;
AccDataType dx = type_convert<AccDataType>(1.0f) / reduceSize * invVar * AccDataType dx = type_convert<AccDataType>(1.0f) / reduceSize * invVar *
scale * (reduceSize * dy - bnBiasDiff - tmpVal); scale * (reduceSize * dy - dbias - tmpVal);
arg.p_dx_[offset] = type_convert<XDataType>(dx); arg.p_dx_[offset] = type_convert<XDataType>(dx);
}; };
...@@ -260,8 +259,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -260,8 +259,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const void* p_savedInvVar, const void* p_savedInvVar,
double epsilon, double epsilon,
void* p_dx, void* p_dx,
void* p_scaleDiff, void* p_dscale,
void* p_biasDiff) override void* p_dbias) override
{ {
return std::make_unique<Argument>(xyLengths, return std::make_unique<Argument>(xyLengths,
xStrides, xStrides,
...@@ -279,8 +278,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch ...@@ -279,8 +278,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
static_cast<const MeanVarDataType*>(p_savedInvVar), static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon, epsilon,
static_cast<DxDataType*>(p_dx), static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_scaleDiff), static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_biasDiff)); static_cast<BiasDataType*>(p_dbias));
}; };
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment