Commit bb220a7a authored by Qianfeng Zhang
Browse files

Renaming again

parent 213187f6
......@@ -427,7 +427,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
arg.invariant_length, arg.blkGroupSize);
const auto scale_bias_diff_grid_desc_m_g =
const auto dscale_dbias_grid_desc_m_g =
DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
arg.invariant_length, arg.blkGroupSize);
......@@ -435,14 +435,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
arg.invariant_length, arg.blkGroupSize);
const auto scale_bias_diff_grid_desc_m_k =
const auto dscale_dbias_grid_desc_m_k =
DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
arg.invariant_length, arg.blkGroupSize);
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using ScaleBiasDiffGridDesc_M_G = decltype(scale_bias_diff_grid_desc_m_g);
using ScaleBiasDiffGridDesc_M_K = decltype(scale_bias_diff_grid_desc_m_k);
using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
using DscaleDbiasGridDesc_M_G = decltype(dscale_dbias_grid_desc_m_g);
using DscaleDbiasGridDesc_M_K = decltype(dscale_dbias_grid_desc_m_k);
using GridwiseWelfordSecondHalfReduceFirstHalf_ =
GridwiseWelfordSecondHalfReduceFirstHalf<XDataType,
......@@ -454,7 +454,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
XYGridDesc_M_K,
MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K,
ScaleBiasDiffGridDesc_M_G,
DscaleDbiasGridDesc_M_G,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
......@@ -474,7 +474,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
BiasDataType,
MeanVarDataType,
XYGridDesc_M_K,
ScaleBiasDiffGridDesc_M_K,
DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M,
ScaleBiasGridDesc_M,
BlockSize,
......@@ -551,7 +551,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
XYGridDesc_M_K,
MeanVarGridDesc_M,
MeanVarCountGridDesc_M_K,
ScaleBiasDiffGridDesc_M_G>;
DscaleDbiasGridDesc_M_G>;
const auto kern_reduce_second_half_batchnorm_backward_final =
kernel_reduce_second_half_batchnorm_backward_final<
......@@ -563,11 +563,11 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
BiasDataType,
MeanVarDataType,
XYGridDesc_M_K,
ScaleBiasDiffGridDesc_M_K,
DscaleDbiasGridDesc_M_K,
MeanVarGridDesc_M,
ScaleBiasGridDesc_M>;
index_t numScaleBiasDiffBlockTileIteration =
index_t numDscaleDbiasBlockTileIteration =
(arg.blkGroupSize + KThreadClusterSize - 1) / KThreadClusterSize;
avg_time += launch_and_time_kernel(
......@@ -580,10 +580,10 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.dy_grid_desc_m_k,
arg.mean_var_grid_desc_m,
mean_var_count_grid_desc_m_k,
scale_bias_diff_grid_desc_m_g,
dscale_dbias_grid_desc_m_g,
arg.blkGroupSize,
arg.numBlockTileIteration,
numScaleBiasDiffBlockTileIteration,
numDscaleDbiasBlockTileIteration,
arg.epsilon_,
arg.haveSavedMeanInvVar_,
arg.haveSavedMeanInvVar_ ? arg.p_savedMean_ : nullptr,
......@@ -616,14 +616,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg.x_grid_desc_m_k,
arg.dy_grid_desc_m_k,
arg.dx_grid_desc_m_k,
scale_bias_diff_grid_desc_m_k,
dscale_dbias_grid_desc_m_k,
arg.mean_var_grid_desc_m,
arg.scale_grid_desc_m,
arg.bias_grid_desc_m,
arg.blkGroupSize,
arg.reduce_length,
arg.numBlockTileIteration,
numScaleBiasDiffBlockTileIteration,
numDscaleDbiasBlockTileIteration,
static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
arg.haveSavedMeanInvVar_
......
......@@ -19,21 +19,21 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename BiasDataType,
typename MeanVarDataType,
typename XYGridDesc_M_K,
typename ScaleBiasDiffGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M,
typename ScaleBiasGridDesc_M>
__global__ void kernel_reduce_second_half_batchnorm_backward_final(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const XYGridDesc_M_K dx_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_K dscale_dbias_grid_desc_m_k,
const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const ScaleBiasGridDesc_M scale_grid_desc_m,
const ScaleBiasGridDesc_M bias_grid_desc_m,
index_t blkgroup_size,
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_scale_bias_diff_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
......@@ -55,7 +55,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
blkgroup_size,
reduce_size,
num_xy_k_block_tile_iteration,
num_scale_bias_diff_k_block_tile_iteration,
num_dscale_dbias_k_block_tile_iteration,
p_reduce_dscale,
p_reduce_dbias,
p_mean,
......@@ -76,7 +76,7 @@ template <typename XDataType,
typename BiasDataType,
typename MeanVarDataType,
typename XYGridDesc_M_K,
typename ScaleBiasDiffGridDesc_M_K,
typename DscaleDbiasGridDesc_M_K,
typename MeanVarGridDesc_M,
typename ScaleBiasGridDesc_M,
index_t BlockSize,
......@@ -148,14 +148,14 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
__device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
const XYGridDesc_M_K& dy_grid_desc_m_k,
const XYGridDesc_M_K& dx_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const ScaleBiasGridDesc_M& scale_grid_desc_m,
const ScaleBiasGridDesc_M& bias_grid_desc_m,
index_t blkgroup_size,
long_index_t reduce_size,
index_t num_xy_k_block_tile_iteration,
index_t num_scale_bias_diff_k_block_tile_iteration,
index_t num_dscale_dbias_k_block_tile_iteration,
const ScaleDataType* const __restrict__ p_reduce_dscale,
const BiasDataType* const __restrict__ p_reduce_dbias,
const MeanVarDataType* const __restrict__ p_mean,
......@@ -220,7 +220,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
auto threadwise_dscale_load_m_k =
ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
AccDataType,
ScaleBiasDiffGridDesc_M_K,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
......@@ -236,7 +236,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
auto threadwise_dbias_load_m_k =
ThreadwiseTensorSliceTransfer_v2<BiasDataType,
AccDataType,
ScaleBiasDiffGridDesc_M_K,
DscaleDbiasGridDesc_M_K,
decltype(thread_buffer_desc_m_1),
ThreadBufferLengths_M_1,
Sequence<0, 1>,
......@@ -305,7 +305,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
dbias_thread_buf(I) = type_convert<AccDataType>(0.0f);
});
for(index_t reducedTiles = 0; reducedTiles < num_scale_bias_diff_k_block_tile_iteration;
for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration;
++reducedTiles)
{
threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
......
......@@ -22,13 +22,13 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasDiffGridDesc_M_G>
typename DscaleDbiasGridDesc_M_G>
__global__ void kernel_welford_second_half_reduce_first_half(
const XYGridDesc_M_K x_grid_desc_m_k,
const XYGridDesc_M_K dy_grid_desc_m_k,
const MeanVarGridDesc_M mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_G dscale_dbias_grid_desc_m_g,
const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
......@@ -78,7 +78,7 @@ template <typename XDataType,
typename XYGridDesc_M_K,
typename MeanVarGridDesc_M,
typename MeanVarCountGridDesc_M_K,
typename ScaleBiasDiffGridDesc_M_G,
typename DscaleDbiasGridDesc_M_G,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
......@@ -154,7 +154,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
const XYGridDesc_M_K& dy_grid_desc_m_k,
const MeanVarGridDesc_M& mean_var_grid_desc_m,
const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
const ScaleBiasDiffGridDesc_M_G& dscale_dbias_grid_desc_m_g,
const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g,
index_t blkgroup_size,
index_t num_xy_k_block_tile_iteration,
index_t num_mean_var_count_k_block_tile_iteration,
......@@ -504,7 +504,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
ScaleDataType,
decltype(thread_buffer_desc_m_1),
ScaleBiasDiffGridDesc_M_G,
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
......@@ -523,7 +523,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
BiasDataType,
decltype(thread_buffer_desc_m_1),
ScaleBiasDiffGridDesc_M_G,
DscaleDbiasGridDesc_M_G,
PassThroughOp,
ThreadBufferLengths_M_1,
Sequence<0, 1>,
......
......@@ -40,8 +40,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const MeanVarDataType* p_savedInvVar,
double epsilon,
DxDataType* p_dx,
ScaleDataType* p_scaleDiff,
BiasDataType* p_biasDiff)
ScaleDataType* p_dscale,
BiasDataType* p_dbias)
: p_x_(p_x),
p_dy_(p_dy),
p_scale_(p_scale),
......@@ -49,8 +49,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
p_savedInvVar_(p_savedInvVar),
epsilon_(epsilon),
p_dx_(p_dx),
p_scaleDiff_(p_scaleDiff),
p_biasDiff_(p_biasDiff)
p_dscale_(p_dscale),
p_dbias_(p_dbias)
{
ignore = xStrides;
ignore = dyStrides;
......@@ -81,8 +81,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
double epsilon_;
DxDataType* p_dx_;
ScaleDataType* p_scaleDiff_;
BiasDataType* p_biasDiff_;
ScaleDataType* p_dscale_;
BiasDataType* p_dbias_;
bool haveSavedMeanInvVar_;
......@@ -113,7 +113,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
meansquare = type_convert<AccDataType>(0.0f);
mean = type_convert<AccDataType>(0.0f);
// compute mean, meanquare, variance, invVariance
// compute mean, meansquare, variance, inv-variance
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
......@@ -142,13 +142,12 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
};
AccDataType bnBiasDiff = type_convert<AccDataType>(0.0f); // Sum on NHW of dy
AccDataType bnScaleDiff =
type_convert<AccDataType>(0.0f); // Sum on NHW of dy * norm_x
AccDataType dbias = type_convert<AccDataType>(0.0f); // Sum on NHW of dy
AccDataType dscale = type_convert<AccDataType>(0.0f); // Sum on NHW of dy * norm_x
// 1) calculate dy * (x - mean) * invVariance
// 2) calculate Sum on NHWC of dy
// 3) calculate Sum on NHWC of dy * norm_x
// 1) calculate dy * (x - mean) * inv-variance
// 2) calculate sum(dy) on NHW dimensions
// 3) calculate sum(dy * norm_x) on NHW dimensions
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
......@@ -166,17 +165,17 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType norm_x = (x - mean) * invVar;
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
bnBiasDiff += dy;
bnScaleDiff += norm_x * dy;
dbias += dy;
dscale += norm_x * dy;
};
}
};
arg.p_scaleDiff_[offset_C] = type_convert<ScaleDataType>(bnScaleDiff);
arg.p_biasDiff_[offset_C] = type_convert<BiasDataType>(bnBiasDiff);
arg.p_dscale_[offset_C] = type_convert<ScaleDataType>(dscale);
arg.p_dbias_[offset_C] = type_convert<BiasDataType>(dbias);
// 1) calculate tmp = scaleDiff * (x - mean) * invVariance
// 2) calculate dx = 1/nhw * invVariance * scale * (nhw * dy - biasDiff - tmp)
// 1) calculate tmp = dscale * (x - mean) * inv-variance
// 2) calculate dx = 1/nhw * inv-variance * scale * (nhw * dy - dbias - tmp)
for(index_t iN = 0; iN < arg.n_; iN++)
{
index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
......@@ -195,10 +194,10 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType dy = type_convert<AccDataType>(arg.p_dy_[offset]);
AccDataType scale = type_convert<AccDataType>(arg.p_scale_[offset_C]);
AccDataType tmpVal = norm_x * bnScaleDiff;
AccDataType tmpVal = norm_x * dscale;
AccDataType dx = type_convert<AccDataType>(1.0f) / reduceSize * invVar *
scale * (reduceSize * dy - bnBiasDiff - tmpVal);
scale * (reduceSize * dy - dbias - tmpVal);
arg.p_dx_[offset] = type_convert<XDataType>(dx);
};
......@@ -260,8 +259,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const void* p_savedInvVar,
double epsilon,
void* p_dx,
void* p_scaleDiff,
void* p_biasDiff) override
void* p_dscale,
void* p_dbias) override
{
return std::make_unique<Argument>(xyLengths,
xStrides,
......@@ -279,8 +278,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
static_cast<const MeanVarDataType*>(p_savedInvVar),
epsilon,
static_cast<DxDataType*>(p_dx),
static_cast<ScaleDataType*>(p_scaleDiff),
static_cast<BiasDataType*>(p_biasDiff));
static_cast<ScaleDataType*>(p_dscale),
static_cast<BiasDataType*>(p_dbias));
};
std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment