Commit 3f392b53 authored by Qianfeng Zhang

Rename parameters in the batchnorm backward kernels and device op

parent d0b49a14
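
Reference note (an editor's aside, not part of the original commit message): the renamed outputs `p_dscale` and `p_dbias` hold the batchnorm backward gradients with respect to the affine scale (gamma) and bias (beta). With $\hat{x}_i = (x_i - \mu)\,\mathrm{invVar}$ over the $N$ reduced elements, the standard formulas are

$$
\frac{\partial L}{\partial \beta} = \sum_{i=1}^{N} dy_i,
\qquad
\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{N} dy_i\,\hat{x}_i .
$$

The `*_diff` to `d*` rename makes the pointer names match this notation.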
@@ -35,8 +35,8 @@ struct DeviceBatchNormBwd : public BaseOperator
             const void* p_savedInvVar,
             double epsilon,
             void* p_dx,
-            void* p_scaleDiff,
-            void* p_biasDiff) = 0;
+            void* p_dscale,
+            void* p_dbias) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
...
@@ -201,8 +201,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
                  const MeanVarDataType* p_savedInvVar,
                  double epsilon,
                  DxDataType* p_dx,
-                 ScaleDataType* p_scaleDiff,
-                 BiasDataType* p_biasDiff)
+                 ScaleDataType* p_dscale,
+                 BiasDataType* p_dbias)
             : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
               bnScaleStrides_(bnScaleStrides),
               bnBiasStrides_(bnBiasStrides),
@@ -213,8 +213,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
               p_savedMean_(p_savedMean),
               p_savedInvVar_(p_savedInvVar),
               p_dx_(p_dx),
-              p_scaleDiff_(p_scaleDiff),
-              p_biasDiff_(p_biasDiff)
+              p_dscale_(p_dscale),
+              p_dbias_(p_dbias)
         {
             xyLengths_ =
                 shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
@@ -294,8 +294,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
         const MeanVarDataType* p_savedMean_;
         const MeanVarDataType* p_savedInvVar_;
         DxDataType* p_dx_;
-        ScaleDataType* p_scaleDiff_;
-        BiasDataType* p_biasDiff_;
+        ScaleDataType* p_dscale_;
+        BiasDataType* p_dbias_;

         long_index_t invariant_length;
         long_index_t reduce_length;
@@ -318,8 +318,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
         void* workspace_savedMean;
         void* workspace_savedInvVar;
-        void* workspace_reduce_scale_diff;
-        void* workspace_reduce_bias_diff;
+        void* workspace_reduce_dscale;
+        void* workspace_reduce_dbias;
     };

     size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
@@ -372,14 +372,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
             index_t space_sz;

             // setup buffer for the partial reduced result for scale_diff
-            pArg_->workspace_reduce_scale_diff = pArg_->p_workspace_;
+            pArg_->workspace_reduce_dscale = pArg_->p_workspace_;

             space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
             space_sz = math::integer_least_multiple(space_sz, 64);

             // setup buffer for the partial reduced result for bias_diff
-            pArg_->workspace_reduce_bias_diff =
-                reinterpret_cast<char*>(pArg_->workspace_reduce_scale_diff) + space_sz;
+            pArg_->workspace_reduce_dbias =
+                reinterpret_cast<char*>(pArg_->workspace_reduce_dscale) + space_sz;

             if(UseMultiblockInK && pArg_->blkGroupSize > 1)
             {
@@ -388,7 +388,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
                 // setup buffer for welford intermediate mean
                 pArg_->workspace_mean =
-                    reinterpret_cast<char*>(pArg_->workspace_reduce_bias_diff) + space_sz;
+                    reinterpret_cast<char*>(pArg_->workspace_reduce_dbias) + space_sz;

                 space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
                 space_sz = math::integer_least_multiple(space_sz, 64);
@@ -604,8 +604,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
                         : static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
                     arg.p_x_,
                     arg.p_dy_,
-                    static_cast<ScaleDataType*>(arg.workspace_reduce_scale_diff),
-                    static_cast<BiasDataType*>(arg.workspace_reduce_bias_diff));
+                    static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
+                    static_cast<BiasDataType*>(arg.workspace_reduce_dbias));

                 avg_time += launch_and_time_kernel(
                     stream_config,
@@ -624,8 +624,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
                     arg.reduce_length,
                     arg.numBlockTileIteration,
                     numScaleBiasDiffBlockTileIteration,
-                    static_cast<const ScaleDataType*>(arg.workspace_reduce_scale_diff),
-                    static_cast<const BiasDataType*>(arg.workspace_reduce_bias_diff),
+                    static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
+                    static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
                     arg.haveSavedMeanInvVar_
                         ? arg.p_savedMean_
                         : static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
@@ -636,8 +636,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
                     arg.p_dy_,
                     arg.p_scale_,
                     arg.p_dx_,
-                    arg.p_scaleDiff_,
-                    arg.p_biasDiff_);
+                    arg.p_dscale_,
+                    arg.p_dbias_);
             }
             else
             {
@@ -708,8 +708,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
                     arg.p_savedMean_,
                     arg.p_savedInvVar_,
                     arg.p_dx_,
-                    arg.p_scaleDiff_,
-                    arg.p_biasDiff_);
+                    arg.p_dscale_,
+                    arg.p_dbias_);
             };

             return (avg_time);
@@ -801,8 +801,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
         const void* p_savedInvVar,
         double epsilon,
         void* p_dx,
-        void* p_scaleDiff,
-        void* p_biasDiff) override
+        void* p_dscale,
+        void* p_dbias) override
     {
         return std::make_unique<Argument>(xyLengths,
                                           xStrides,
@@ -820,8 +820,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>
                                           static_cast<const MeanVarDataType*>(p_savedInvVar),
                                           epsilon,
                                           static_cast<DxDataType*>(p_dx),
-                                          static_cast<ScaleDataType*>(p_scaleDiff),
-                                          static_cast<BiasDataType*>(p_biasDiff));
+                                          static_cast<ScaleDataType*>(p_dscale),
+                                          static_cast<BiasDataType*>(p_dbias));
     };

     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
...
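
An editor's aside on the two workspace hunks above: the device op carves a single `p_workspace_` allocation into 64-byte-aligned sub-buffers, one per partial result, in a fixed order (dscale partials first, then dbias, then the Welford intermediates). A minimal standalone sketch of that pattern (hypothetical sizes and helper name; not the library code):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Round sz up to the next multiple of align; stands in for the library's
// math::integer_least_multiple used in the hunks above.
static std::size_t round_up(std::size_t sz, std::size_t align)
{
    return (sz + align - 1) / align * align;
}

int main()
{
    // Hypothetical sizes, for illustration only.
    const std::size_t invariant_length = 256; // e.g. the channel count C
    const std::size_t blkGroupSize     = 4;   // blocks cooperating per reduction
    const std::size_t space_sz =
        round_up(invariant_length * blkGroupSize * sizeof(float), 64);

    std::vector<char> workspace(2 * space_sz);

    // Same carving order as the device op: dscale partials first, then dbias.
    void* workspace_reduce_dscale = workspace.data();
    void* workspace_reduce_dbias  = workspace.data() + space_sz;

    std::printf("dscale partials at offset 0, dbias partials at offset %zu\n", space_sz);
    (void)workspace_reduce_dscale;
    (void)workspace_reduce_dbias;
}
```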
@@ -34,16 +34,16 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
     long_index_t reduce_size,
     index_t num_xy_k_block_tile_iteration,
     index_t num_scale_bias_diff_k_block_tile_iteration,
-    const ScaleDataType* const __restrict__ p_reduce_scale_diff,
-    const BiasDataType* const __restrict__ p_reduce_bias_diff,
+    const ScaleDataType* const __restrict__ p_reduce_dscale,
+    const BiasDataType* const __restrict__ p_reduce_dbias,
     const MeanVarDataType* const __restrict__ p_mean,
     const MeanVarDataType* const __restrict__ p_inv_var,
     const XDataType* const __restrict__ p_x,
     const DyDataType* const __restrict__ p_dy,
     const ScaleDataType* const __restrict__ p_scale,
     DxDataType* const __restrict__ p_dx,
-    ScaleDataType* const __restrict__ p_scale_diff,
-    BiasDataType* const __restrict__ p_bias_diff)
+    ScaleDataType* const __restrict__ p_dscale,
+    BiasDataType* const __restrict__ p_dbias)
 {
     GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
                                                          dy_grid_desc_m_k,
@@ -56,16 +56,16 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
                                                          reduce_size,
                                                          num_xy_k_block_tile_iteration,
                                                          num_scale_bias_diff_k_block_tile_iteration,
-                                                         p_reduce_scale_diff,
-                                                         p_reduce_bias_diff,
+                                                         p_reduce_dscale,
+                                                         p_reduce_dbias,
                                                          p_mean,
                                                          p_inv_var,
                                                          p_x,
                                                          p_dy,
                                                          p_scale,
                                                          p_dx,
-                                                         p_scale_diff,
-                                                         p_bias_diff);
+                                                         p_dscale,
+                                                         p_dbias);
 };

 template <typename XDataType,
@@ -151,16 +151,16 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                long_index_t reduce_size,
                                index_t num_xy_k_block_tile_iteration,
                                index_t num_scale_bias_diff_k_block_tile_iteration,
-                               const ScaleDataType* const __restrict__ p_reduce_scale_diff,
-                               const BiasDataType* const __restrict__ p_reduce_bias_diff,
+                               const ScaleDataType* const __restrict__ p_reduce_dscale,
+                               const BiasDataType* const __restrict__ p_reduce_dbias,
                                const MeanVarDataType* const __restrict__ p_mean,
                                const MeanVarDataType* const __restrict__ p_inv_var,
                                const XDataType* const __restrict__ p_x,
                                const DyDataType* const __restrict__ p_dy,
                                const ScaleDataType* const __restrict__ p_scale,
                                DxDataType* const __restrict__ p_dx,
-                               ScaleDataType* const __restrict__ p_scale_diff,
-                               BiasDataType* const __restrict__ p_bias_diff)
+                               ScaleDataType* const __restrict__ p_dscale,
+                               BiasDataType* const __restrict__ p_dbias)
     {
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -281,16 +281,16 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                      PassThroughOp{});

         const auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_scale_diff, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
+            p_reduce_dscale, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());

         const auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_bias_diff, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
+            p_reduce_dbias, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());

         auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_scale_diff, scale_grid_desc_m.GetElementSpaceSize());
+            p_dscale, scale_grid_desc_m.GetElementSpaceSize());

         auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_bias_diff, bias_grid_desc_m.GetElementSpaceSize());
+            p_dbias, bias_grid_desc_m.GetElementSpaceSize());

         constexpr auto scale_bias_diff_thread_copy_step_m_k =
             make_multi_index(0, KThreadClusterSize * 1);
...
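
Editor's note on the multiblock path: the final kernel above consumes per-block partial sums (`p_reduce_dscale`, `p_reduce_dbias`) written to the workspace by the first-half kernel, and folds them into the final `p_dscale`/`p_dbias`. A host-side sketch of that folding step (illustrative layout assumption, not the kernel code):

```cpp
#include <cstddef>
#include <vector>

// Illustrative host-side equivalent of the final kernel's second-half
// reduction: for each invariant index m (e.g. a channel), sum the g_len
// partials written by the first-half kernel. The [m][g] layout is an
// assumption made for this sketch.
void reduce_second_half(const std::vector<float>& partial_dscale,
                        const std::vector<float>& partial_dbias,
                        std::vector<float>& dscale,
                        std::vector<float>& dbias,
                        std::size_t m_len,
                        std::size_t g_len)
{
    for(std::size_t m = 0; m < m_len; ++m)
    {
        float acc_dscale = 0.f;
        float acc_dbias  = 0.f;
        for(std::size_t g = 0; g < g_len; ++g)
        {
            acc_dscale += partial_dscale[m * g_len + g];
            acc_dbias  += partial_dbias[m * g_len + g];
        }
        dscale[m] = acc_dscale;
        dbias[m]  = acc_dbias;
    }
}
```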
@@ -43,8 +43,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
     MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
     const XDataType* const __restrict__ p_x,
     const DyDataType* const __restrict__ p_dy,
-    ScaleDataType* const __restrict__ p_reduce_scale_diff,
-    BiasDataType* const __restrict__ p_reduce_bias_diff)
+    ScaleDataType* const __restrict__ p_reduce_dscale,
+    BiasDataType* const __restrict__ p_reduce_dbias)
 {
     GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
                                                    dy_grid_desc_m_k,
@@ -65,8 +65,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
                                                    p_out_welford_inv_variance,
                                                    p_x,
                                                    p_dy,
-                                                   p_reduce_scale_diff,
-                                                   p_reduce_bias_diff);
+                                                   p_reduce_dscale,
+                                                   p_reduce_dbias);
 };

 template <typename XDataType,
@@ -164,8 +164,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                                MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
                                const XDataType* const __restrict__ p_x,
                                const DyDataType* const __restrict__ p_dy,
-                               ScaleDataType* const __restrict__ p_reduce_scale_diff,
-                               BiasDataType* const __restrict__ p_reduce_bias_diff)
+                               ScaleDataType* const __restrict__ p_reduce_dscale,
+                               BiasDataType* const __restrict__ p_reduce_dbias)
     {
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -531,10 +531,10 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                                      PassThroughOp{});

         auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_scale_diff, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
+            p_reduce_dscale, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());

         auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_bias_diff, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
+            p_reduce_dbias, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());

         if(thread_k_cluster_id == 0)
         {
...
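
For readers unfamiliar with the name: "Welford" in these kernels refers to Welford's single-pass, numerically stable mean/variance algorithm. A scalar sketch of the update that the blockwise/gridwise variants parallelize (illustrative only; the kernels additionally merge partial `(count, mean, m2)` states across threads and blocks):

```cpp
#include <cstddef>

// Scalar Welford update: one pass over the data, numerically stable.
struct WelfordState
{
    std::size_t count = 0;
    float mean = 0.f;
    float m2   = 0.f; // running sum of squared deviations from the mean

    void update(float x)
    {
        ++count;
        const float delta = x - mean;
        mean += delta / static_cast<float>(count);
        m2 += delta * (x - mean);
    }

    // Population variance; the kernels derive invVar = 1/sqrt(variance + epsilon).
    float variance() const { return count ? m2 / static_cast<float>(count) : 0.f; }
};
```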
@@ -45,8 +45,8 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
     const MeanVarDataType* const __restrict__ p_savedMean,
     const MeanVarDataType* const __restrict__ p_savedInvVar,
     DxDataType* const __restrict__ p_dx,
-    ScaleDataType* const __restrict__ p_scale_diff,
-    BiasDataType* const __restrict__ p_bias_diff)
+    ScaleDataType* const __restrict__ p_dscale,
+    BiasDataType* const __restrict__ p_dbias)
 {
     GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
                                                          dy_grid_desc_m_k,
@@ -65,8 +65,8 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
                                                          p_savedMean,
                                                          p_savedInvVar,
                                                          p_dx,
-                                                         p_scale_diff,
-                                                         p_bias_diff);
+                                                         p_dscale,
+                                                         p_dbias);
 };

 template <typename XDataType,
@@ -166,8 +166,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                                const MeanVarDataType* const __restrict__ p_savedMean,
                                const MeanVarDataType* const __restrict__ p_savedInvVar,
                                DxDataType* const __restrict__ p_dx,
-                               ScaleDataType* const __restrict__ p_scale_diff,
-                               BiasDataType* const __restrict__ p_bias_diff)
+                               ScaleDataType* const __restrict__ p_dscale,
+                               BiasDataType* const __restrict__ p_dbias)
     {
         using ck::math::sqrt;
@@ -333,10 +333,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
             p_scale, scale_grid_desc_m.GetElementSpaceSize());

         auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_scale_diff, scale_grid_desc_m.GetElementSpaceSize());
+            p_dscale, scale_grid_desc_m.GetElementSpaceSize());

         auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_bias_diff, bias_grid_desc_m.GetElementSpaceSize());
+            p_dbias, bias_grid_desc_m.GetElementSpaceSize());

         if(haveSavedMeanInvVar)
         {
...
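
For completeness (standard batchnorm backward math, not text from this commit): when the mean and variance are computed from the batch, the input gradient produced alongside `p_dscale`/`p_dbias` is

$$
\frac{\partial L}{\partial x_i}
= \frac{\gamma\,\mathrm{invVar}}{N}
  \left( N\,dy_i - \frac{\partial L}{\partial \beta} - \hat{x}_i\,\frac{\partial L}{\partial \gamma} \right),
$$

which is why the kernels above compute dscale and dbias before (or together with) dx.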