"src/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "5b380b89b5104ded935efd86dae0bfe94e106d03"
Commit 2304e2f0 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

fix bn bugs
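The MIOpen batch-norm routines only support the NCHW compute format (see the TODO comments in the diff), but with FLAGS_cudnn_batchnorm_spatial_persistent enabled an NHWC input could still reach the MIOpen branch with NHWC descriptors. This commit routes that case to kNCHW, presumably relying on the existing NHWC-to-NCHW transform earlier in these functions, deletes the now-unreachable NHWC MIOpen branches in the backward, inference, and training paths, and aligns the remaining MIOpen calls with the surrounding code: outputs are allocated via ctx.template Alloc<...>() instead of the legacy mutable_data(ctx.GetPlace()), and parameters are read through new_scale/new_bias.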

parent ca9dbdb2
@@ -589,11 +589,11 @@ void BatchNormGradFunctor(const Context &ctx,
   auto dtype = phi::backends::gpu::CudnnDataType<T>::type;
 #ifdef PADDLE_WITH_HIP
   auto compute_format =
-      data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;
+      data_layout == DataLayout::kNHWC ? (FLAGS_cudnn_batchnorm_spatial_persistent == true ? DataLayout::kNCHW : DataLayout::kNHWC) : DataLayout::kNCHW;
   // TODO(wangran16): wait for MIOpen to improve the performance of BN
   // HIP do not support compute format of NHWC
   // auto compute_format = DataLayout::kNCHW;
 #else
   const bool fast_nhwc_batch_norm = dtype == CUDNN_DATA_HALF &&
                                     FLAGS_cudnn_batchnorm_spatial_persistent &&
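Note: the one-line change above is the core of the fix. A minimal standalone sketch of the new selection rule (the DataLayout values and flag name are taken from the diff; the enum and flag below are simplified stand-ins, not Paddle's definitions):

#include <cstdio>

// Simplified stand-ins for Paddle's DataLayout and the GFlags switch.
enum class DataLayout { kNCHW, kNHWC };
static bool FLAGS_cudnn_batchnorm_spatial_persistent = false;

// New rule from the diff: NHWC data stays NHWC only when the
// spatial-persistent flag is off; with the flag on it is routed to NCHW,
// the only compute format the MIOpen batch-norm routines accept.
DataLayout ChooseComputeFormat(DataLayout data_layout) {
  return data_layout == DataLayout::kNHWC
             ? (FLAGS_cudnn_batchnorm_spatial_persistent ? DataLayout::kNCHW
                                                         : DataLayout::kNHWC)
             : DataLayout::kNCHW;
}

int main() {
  FLAGS_cudnn_batchnorm_spatial_persistent = true;
  // NHWC input now computes in NCHW, which MIOpen can handle.
  std::puts(ChooseComputeFormat(DataLayout::kNHWC) == DataLayout::kNCHW
                ? "kNCHW"
                : "kNHWC");
}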
@@ -762,12 +762,10 @@ void BatchNormGradFunctor(const Context &ctx,
                 CudnnDataType<T>::kZero(), data_desc_,
                 transformed_x.template data<T>(), data_desc_,
                 transformed_d_y.template data<T>(), data_desc_,
-                transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
-                bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
-                d_scale->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                d_bias->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
+                ctx.template Alloc<T>(&transformed_d_x),
+                bn_param_desc_, new_scale.template data<BatchNormParamType<T>>(),
+                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
+                ctx.template Alloc<BatchNormParamType<T>>(d_bias),
                 epsilon, saved_mean_data, saved_var_data));
       } else {
         BNBackward<T, block, DataLayout::kNCHW>
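Note: this hunk also migrates output allocation from the legacy DenseTensor::mutable_data<T>(place) call to the device-context allocator, matching the non-MIOpen branches below. A toy sketch of the two patterns (the tensor and context types here are illustrative stand-ins, not phi's real classes):

#include <cstddef>
#include <vector>

// Illustrative stand-ins for phi::DenseTensor and the device context.
struct DenseTensor {
  std::vector<float> storage;
};

struct Context {
  // New style used by the diff: the context performs the allocation,
  // so the buffer is tied to the correct device and stream.
  template <typename T>
  T *Alloc(DenseTensor *t, std::size_t n = 256) {
    t->storage.resize(n);
    return reinterpret_cast<T *>(t->storage.data());
  }
};

int main() {
  Context ctx;
  DenseTensor transformed_d_x;
  // Old style (removed): transformed_d_x.template mutable_data<T>(ctx.GetPlace())
  // New style (added); inside a template, spell it ctx.template Alloc<T>(&t):
  float *d_x_ptr = ctx.Alloc<float>(&transformed_d_x);
  (void)d_x_ptr;
  return 0;
}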
@@ -785,22 +783,6 @@ void BatchNormGradFunctor(const Context &ctx,
                 ctx.template Alloc<BatchNormParamType<T>>(d_scale),
                 ctx.template Alloc<BatchNormParamType<T>>(d_bias));
       }
-    } else {
-      if (FLAGS_cudnn_batchnorm_spatial_persistent == true) {
-        PADDLE_ENFORCE_GPU_SUCCESS(
-            phi::dynload::miopenBatchNormalizationBackward(
-                ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
-                CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
-                CudnnDataType<T>::kZero(), data_desc_,
-                transformed_x.template data<T>(), data_desc_,
-                transformed_d_y.template data<T>(), data_desc_,
-                transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
-                bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
-                d_scale->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                d_bias->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                epsilon, saved_mean_data, saved_var_data));
       } else {
         BNBackward<T, block, DataLayout::kNHWC>
             <<<grid2, block, 0, ctx.stream()>>>(
@@ -817,7 +799,6 @@ void BatchNormGradFunctor(const Context &ctx,
                 ctx.template Alloc<BatchNormParamType<T>>(d_scale),
                 ctx.template Alloc<BatchNormParamType<T>>(d_bias));
       }
-    }
 #else
     }
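Note: with compute_format forced to kNCHW whenever the flag is set, the deleted branch above was unreachable, so the NHWC arm now only dispatches the native kernel. A condensed, runnable sketch of the resulting backward dispatch (the three calls are stubs; the control flow is distilled from the hunks above, not the literal source):

#include <cstdio>

enum class DataLayout { kNCHW, kNHWC };
static bool FLAGS_cudnn_batchnorm_spatial_persistent = false;

// Stubs standing in for the real dispatch targets.
void MiopenBackward() { std::puts("miopenBatchNormalizationBackward (NCHW)"); }
void NativeBackwardNCHW() { std::puts("BNBackward<T, block, kNCHW>"); }
void NativeBackwardNHWC() { std::puts("BNBackward<T, block, kNHWC>"); }

// Backward dispatch on HIP after this commit.
void DispatchBackward(DataLayout compute_format) {
  if (compute_format == DataLayout::kNCHW) {  // also covers NHWC + flag set
    if (FLAGS_cudnn_batchnorm_spatial_persistent) {
      MiopenBackward();                       // MIOpen path, NCHW only
    } else {
      NativeBackwardNCHW();
    }
  } else {                                    // NHWC with the flag unset
    NativeBackwardNHWC();
  }
}

int main() {
  FLAGS_cudnn_batchnorm_spatial_persistent = true;
  DispatchBackward(DataLayout::kNCHW);  // takes the MIOpen branch
}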
...
@@ -572,7 +572,7 @@ void BatchNormKernel(const Context &ctx,
 #ifdef PADDLE_WITH_HIP
   auto compute_format =
-      data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;
+      data_layout == DataLayout::kNHWC ? (FLAGS_cudnn_batchnorm_spatial_persistent == true ? DataLayout::kNCHW : DataLayout::kNHWC) : DataLayout::kNCHW;
   // TODO(wangran16): wait for MIOpen to improve the performance of BN
   // HIP do not support compute format of NHWC
@@ -752,12 +752,12 @@ void BatchNormKernel(const Context &ctx,
                 static_cast<const void *>(transformed_x.template data<T>()),
                 data_desc_,
                 static_cast<void *>(
-                    transformed_y.template mutable_data<T>(ctx.GetPlace())),
+                    ctx.template Alloc<T>(&transformed_y)),
                 bn_param_desc_,
                 const_cast<void *>(static_cast<const void *>(
-                    scale->template data<BatchNormParamType<T>>())),
+                    new_scale.template data<BatchNormParamType<T>>())),
                 const_cast<void *>(static_cast<const void *>(
-                    bias->template data<BatchNormParamType<T>>())),
+                    new_bias.template data<BatchNormParamType<T>>())),
                 const_cast<void *>(static_cast<const void *>(
                     est_mean->template data<BatchNormParamType<T>>())),
                 const_cast<void *>(static_cast<const void *>(
@@ -777,30 +777,6 @@ void BatchNormKernel(const Context &ctx,
                 epsilon,
                 transformed_y.template data<T>());
       }
-    } else {
-      if (FLAGS_cudnn_batchnorm_spatial_persistent == true) {
-        PADDLE_ENFORCE_GPU_SUCCESS(
-            phi::dynload::miopenBatchNormalizationForwardInference(
-                handle, mode_,
-                const_cast<void *>(
-                    static_cast<const void *>(CudnnDataType<T>::kOne())),
-                const_cast<void *>(
-                    static_cast<const void *>(CudnnDataType<T>::kZero())),
-                data_desc_,
-                static_cast<const void *>(transformed_x.template data<T>()),
-                data_desc_,
-                static_cast<void *>(
-                    transformed_y.template mutable_data<T>(ctx.GetPlace())),
-                bn_param_desc_,
-                const_cast<void *>(static_cast<const void *>(
-                    scale->template data<BatchNormParamType<T>>())),
-                const_cast<void *>(static_cast<const void *>(
-                    bias->template data<BatchNormParamType<T>>())),
-                const_cast<void *>(static_cast<const void *>(
-                    est_mean->template data<BatchNormParamType<T>>())),
-                const_cast<void *>(static_cast<const void *>(
-                    est_var->template data<BatchNormParamType<T>>())),
-                epsilon));
       } else {
         BNForwardInference<T, DataLayout::kNHWC>
             <<<grid_size, block_size, 0, ctx.stream()>>>(
@@ -815,7 +791,6 @@ void BatchNormKernel(const Context &ctx,
                 epsilon,
                 transformed_y.template data<T>());
       }
-    }
 #else
     const bool use_native_kernel =
@@ -943,24 +918,20 @@ void BatchNormKernel(const Context &ctx,
                 static_cast<const void *>(transformed_x.template data<T>()),
                 data_desc_,
                 static_cast<void *>(
-                    transformed_y.template mutable_data<T>(ctx.GetPlace())),
+                    ctx.template Alloc<T>(&transformed_y)),
                 bn_param_desc_,
                 const_cast<void *>(static_cast<const void *>(
-                    scale->template data<BatchNormParamType<T>>())),
+                    new_scale.template data<BatchNormParamType<T>>())),
                 const_cast<void *>(static_cast<const void *>(
-                    bias->template data<BatchNormParamType<T>>())),
+                    new_bias.template data<BatchNormParamType<T>>())),
                 this_factor,
                 static_cast<void *>(
-                    mean_out->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace())),
-                static_cast<void *>(variance_out->template mutable_data<
-                                    BatchNormParamType<T>>(ctx.GetPlace())),
+                    ctx.template Alloc<BatchNormParamType<T>>(mean_out)),
+                static_cast<void *>(ctx.template Alloc<BatchNormParamType<T>>(variance_out)),
                 epsilon,
                 static_cast<void *>(
-                    saved_mean->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace())),
-                static_cast<void *>(saved_variance->template mutable_data<
-                                    BatchNormParamType<T>>(ctx.GetPlace()))));
+                    ctx.template Alloc<BatchNormParamType<T>>(saved_mean)),
+                static_cast<void *>(ctx.template Alloc<BatchNormParamType<T>>(saved_variance))));
       } else {
         BNForwardTraining<T, block, DataLayout::kNCHW>
             <<<grid, block, 0, ctx.stream()>>>(
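Note: in the training hunk above, this_factor is the exponential-average factor MIOpen uses to fold the batch statistics into mean_out/variance_out. A sketch of that semantics (assuming the standard expAvgFactor convention shared by cuDNN and MIOpen; not Paddle code):

#include <cstdio>

// running <- (1 - factor) * running + factor * batch_stat
double UpdateRunningStat(double running, double batch_stat, double factor) {
  return (1.0 - factor) * running + factor * batch_stat;
}

int main() {
  // e.g. with momentum 0.9, Paddle passes this_factor = 1 - momentum = 0.1
  std::printf("%.3f\n", UpdateRunningStat(0.0, 1.0, 0.1));  // prints 0.100
}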
@@ -978,36 +949,6 @@ void BatchNormKernel(const Context &ctx,
                 saved_mean->template data<BatchNormParamType<T>>(),
                 saved_variance->template data<BatchNormParamType<T>>());
       }
-    } else {
-      if (FLAGS_cudnn_batchnorm_spatial_persistent == true) {
-        PADDLE_ENFORCE_GPU_SUCCESS(
-            phi::dynload::miopenBatchNormalizationForwardTraining(
-                handle, mode_, const_cast<void *>(static_cast<const void *>(
-                                   CudnnDataType<T>::kOne())),
-                const_cast<void *>(
-                    static_cast<const void *>(CudnnDataType<T>::kZero())),
-                data_desc_,
-                static_cast<const void *>(transformed_x.template data<T>()),
-                data_desc_,
-                static_cast<void *>(
-                    transformed_y.template mutable_data<T>(ctx.GetPlace())),
-                bn_param_desc_,
-                const_cast<void *>(static_cast<const void *>(
-                    scale->template data<BatchNormParamType<T>>())),
-                const_cast<void *>(static_cast<const void *>(
-                    bias->template data<BatchNormParamType<T>>())),
-                this_factor,
-                static_cast<void *>(
-                    mean_out->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace())),
-                static_cast<void *>(variance_out->template mutable_data<
-                                    BatchNormParamType<T>>(ctx.GetPlace())),
-                epsilon,
-                static_cast<void *>(
-                    saved_mean->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace())),
-                static_cast<void *>(saved_variance->template mutable_data<
-                                    BatchNormParamType<T>>(ctx.GetPlace()))));
       } else {
         BNForwardTraining<T, block, DataLayout::kNHWC>
             <<<grid, block, 0, ctx.stream()>>>(
@@ -1025,7 +966,6 @@ void BatchNormKernel(const Context &ctx,
                 saved_mean->template data<BatchNormParamType<T>>(),
                 saved_variance->template data<BatchNormParamType<T>>());
       }
-    }
 #else
     // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
...
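Note: FLAGS_cudnn_batchnorm_spatial_persistent is an ordinary Paddle GFlags switch, so the fixed path can be exercised by setting it in the environment (for example, FLAGS_cudnn_batchnorm_spatial_persistent=true) before running an NHWC batch-norm model on a HIP build.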