Commit 2304e2f0 authored by yuguo-Jack's avatar yuguo-Jack
Browse files

fix bn bugs

parent ca9dbdb2
...@@ -589,11 +589,11 @@ void BatchNormGradFunctor(const Context &ctx, ...@@ -589,11 +589,11 @@ void BatchNormGradFunctor(const Context &ctx,
auto dtype = phi::backends::gpu::CudnnDataType<T>::type; auto dtype = phi::backends::gpu::CudnnDataType<T>::type;
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
auto compute_format = auto compute_format =
data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW; data_layout == DataLayout::kNHWC ? (FLAGS_cudnn_batchnorm_spatial_persistent == true ? DataLayout::kNCHW : DataLayout::kNHWC) : DataLayout::kNCHW;
// TODO(wangran16): wait for MIOpen to improve the performance of BN // TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC // HIP do not support compute format of NHWC
// auto compute_format = DataLayout::kNCHW; // auto compute_format = DataLayout::kNCHW;
#else #else
const bool fast_nhwc_batch_norm = dtype == CUDNN_DATA_HALF && const bool fast_nhwc_batch_norm = dtype == CUDNN_DATA_HALF &&
FLAGS_cudnn_batchnorm_spatial_persistent && FLAGS_cudnn_batchnorm_spatial_persistent &&
...@@ -762,12 +762,10 @@ void BatchNormGradFunctor(const Context &ctx, ...@@ -762,12 +762,10 @@ void BatchNormGradFunctor(const Context &ctx,
CudnnDataType<T>::kZero(), data_desc_, CudnnDataType<T>::kZero(), data_desc_,
transformed_x.template data<T>(), data_desc_, transformed_x.template data<T>(), data_desc_,
transformed_d_y.template data<T>(), data_desc_, transformed_d_y.template data<T>(), data_desc_,
transformed_d_x.template mutable_data<T>(ctx.GetPlace()), ctx.template Alloc<T>(&transformed_d_x),
bn_param_desc_, scale->template data<BatchNormParamType<T>>(), bn_param_desc_, new_scale.template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>( ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.GetPlace()), ctx.template Alloc<BatchNormParamType<T>>(d_bias),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data)); epsilon, saved_mean_data, saved_var_data));
} else { } else {
BNBackward<T, block, DataLayout::kNCHW> BNBackward<T, block, DataLayout::kNCHW>
...@@ -786,37 +784,20 @@ void BatchNormGradFunctor(const Context &ctx, ...@@ -786,37 +784,20 @@ void BatchNormGradFunctor(const Context &ctx,
ctx.template Alloc<BatchNormParamType<T>>(d_bias)); ctx.template Alloc<BatchNormParamType<T>>(d_bias));
} }
} else { } else {
if (FLAGS_cudnn_batchnorm_spatial_persistent == true) { BNBackward<T, block, DataLayout::kNHWC>
PADDLE_ENFORCE_GPU_SUCCESS( <<<grid2, block, 0, ctx.stream()>>>(
phi::dynload::miopenBatchNormalizationBackward( transformed_d_y.template data<T>(),
ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(), transformed_x.template data<T>(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(), new_scale.template data<BatchNormParamType<T>>(),
CudnnDataType<T>::kZero(), data_desc_, saved_mean_data,
transformed_x.template data<T>(), data_desc_, saved_var_data,
transformed_d_y.template data<T>(), data_desc_, C,
transformed_d_x.template mutable_data<T>(ctx.GetPlace()), N,
bn_param_desc_, scale->template data<BatchNormParamType<T>>(), H * W * D,
d_scale->template mutable_data<BatchNormParamType<T>>( epsilon,
ctx.GetPlace()), transformed_d_x.template data<T>(),
d_bias->template mutable_data<BatchNormParamType<T>>( ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.GetPlace()), ctx.template Alloc<BatchNormParamType<T>>(d_bias));
epsilon, saved_mean_data, saved_var_data));
} else {
BNBackward<T, block, DataLayout::kNHWC>
<<<grid2, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
new_scale.template data<BatchNormParamType<T>>(),
saved_mean_data,
saved_var_data,
C,
N,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
}
} }
#else #else
......
...@@ -572,7 +572,7 @@ void BatchNormKernel(const Context &ctx, ...@@ -572,7 +572,7 @@ void BatchNormKernel(const Context &ctx,
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
auto compute_format = auto compute_format =
data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW; data_layout == DataLayout::kNHWC ? (FLAGS_cudnn_batchnorm_spatial_persistent == true ? DataLayout::kNCHW : DataLayout::kNHWC) : DataLayout::kNCHW;
// TODO(wangran16): wait for MIOpen to improve the performance of BN // TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC // HIP do not support compute format of NHWC
...@@ -752,12 +752,12 @@ void BatchNormKernel(const Context &ctx, ...@@ -752,12 +752,12 @@ void BatchNormKernel(const Context &ctx,
static_cast<const void *>(transformed_x.template data<T>()), static_cast<const void *>(transformed_x.template data<T>()),
data_desc_, data_desc_,
static_cast<void *>( static_cast<void *>(
transformed_y.template mutable_data<T>(ctx.GetPlace())), ctx.template Alloc<T>(&transformed_y)),
bn_param_desc_, bn_param_desc_,
const_cast<void *>(static_cast<const void *>( const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())), new_scale.template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>( const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())), new_bias.template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>( const_cast<void *>(static_cast<const void *>(
est_mean->template data<BatchNormParamType<T>>())), est_mean->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>( const_cast<void *>(static_cast<const void *>(
...@@ -778,43 +778,18 @@ void BatchNormKernel(const Context &ctx, ...@@ -778,43 +778,18 @@ void BatchNormKernel(const Context &ctx,
transformed_y.template data<T>()); transformed_y.template data<T>());
} }
} else { } else {
if (FLAGS_cudnn_batchnorm_spatial_persistent == true) { BNForwardInference<T, DataLayout::kNHWC>
PADDLE_ENFORCE_GPU_SUCCESS( <<<grid_size, block_size, 0, ctx.stream()>>>(
phi::dynload::miopenBatchNormalizationForwardInference( transformed_x.template data<T>(),
handle, mode_, est_mean->template data<BatchNormParamType<T>>(),
const_cast<void *>( est_var->template data<BatchNormParamType<T>>(),
static_cast<const void *>(CudnnDataType<T>::kOne())), new_scale.template data<BatchNormParamType<T>>(),
const_cast<void *>( new_bias.template data<BatchNormParamType<T>>(),
static_cast<const void *>(CudnnDataType<T>::kZero())), C,
data_desc_, N,
static_cast<const void *>(transformed_x.template data<T>()), H * W * D,
data_desc_, epsilon,
static_cast<void *>( transformed_y.template data<T>());
transformed_y.template mutable_data<T>(ctx.GetPlace())),
bn_param_desc_,
const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
est_mean->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
est_var->template data<BatchNormParamType<T>>())),
epsilon));
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
new_scale.template data<BatchNormParamType<T>>(),
new_bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
}
} }
#else #else
...@@ -943,24 +918,20 @@ void BatchNormKernel(const Context &ctx, ...@@ -943,24 +918,20 @@ void BatchNormKernel(const Context &ctx,
static_cast<const void *>(transformed_x.template data<T>()), static_cast<const void *>(transformed_x.template data<T>()),
data_desc_, data_desc_,
static_cast<void *>( static_cast<void *>(
transformed_y.template mutable_data<T>(ctx.GetPlace())), ctx.template Alloc<T>(&transformed_y)),
bn_param_desc_, bn_param_desc_,
const_cast<void *>(static_cast<const void *>( const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())), new_scale.template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>( const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())), new_bias.template data<BatchNormParamType<T>>())),
this_factor, this_factor,
static_cast<void *>( static_cast<void *>(
mean_out->template mutable_data<BatchNormParamType<T>>( ctx.template Alloc<BatchNormParamType<T>>(mean_out)),
ctx.GetPlace())), static_cast<void *>(ctx.template Alloc<BatchNormParamType<T>>(variance_out)),
static_cast<void *>(variance_out->template mutable_data<
BatchNormParamType<T>>(ctx.GetPlace())),
epsilon, epsilon,
static_cast<void *>( static_cast<void *>(
saved_mean->template mutable_data<BatchNormParamType<T>>( ctx.template Alloc<BatchNormParamType<T>>(saved_mean)),
ctx.GetPlace())), static_cast<void *>(ctx.template Alloc<BatchNormParamType<T>>(saved_variance))));
static_cast<void *>(saved_variance->template mutable_data<
BatchNormParamType<T>>(ctx.GetPlace()))));
} else { } else {
BNForwardTraining<T, block, DataLayout::kNCHW> BNForwardTraining<T, block, DataLayout::kNCHW>
<<<grid, block, 0, ctx.stream()>>>( <<<grid, block, 0, ctx.stream()>>>(
...@@ -979,52 +950,21 @@ void BatchNormKernel(const Context &ctx, ...@@ -979,52 +950,21 @@ void BatchNormKernel(const Context &ctx,
saved_variance->template data<BatchNormParamType<T>>()); saved_variance->template data<BatchNormParamType<T>>());
} }
} else { } else {
if (FLAGS_cudnn_batchnorm_spatial_persistent == true) { BNForwardTraining<T, block, DataLayout::kNHWC>
PADDLE_ENFORCE_GPU_SUCCESS( <<<grid, block, 0, ctx.stream()>>>(
phi::dynload::miopenBatchNormalizationForwardTraining( transformed_x.template data<T>(),
handle, mode_, const_cast<void *>(static_cast<const void *>( new_scale.template data<BatchNormParamType<T>>(),
CudnnDataType<T>::kOne())), new_bias.template data<BatchNormParamType<T>>(),
const_cast<void *>( C,
static_cast<const void *>(CudnnDataType<T>::kZero())), N,
data_desc_, H * W * D,
static_cast<const void *>(transformed_x.template data<T>()),
data_desc_,
static_cast<void *>(
transformed_y.template mutable_data<T>(ctx.GetPlace())),
bn_param_desc_,
const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())),
this_factor,
static_cast<void *>(
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())),
static_cast<void *>(variance_out->template mutable_data<
BatchNormParamType<T>>(ctx.GetPlace())),
epsilon, epsilon,
static_cast<void *>( this_factor,
saved_mean->template mutable_data<BatchNormParamType<T>>( transformed_y.template data<T>(),
ctx.GetPlace())), mean_out->template data<BatchNormParamType<T>>(),
static_cast<void *>(saved_variance->template mutable_data< variance_out->template data<BatchNormParamType<T>>(),
BatchNormParamType<T>>(ctx.GetPlace())))); saved_mean->template data<BatchNormParamType<T>>(),
} else { saved_variance->template data<BatchNormParamType<T>>());
BNForwardTraining<T, block, DataLayout::kNHWC>
<<<grid, block, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
new_scale.template data<BatchNormParamType<T>>(),
new_bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
this_factor,
transformed_y.template data<T>(),
mean_out->template data<BatchNormParamType<T>>(),
variance_out->template data<BatchNormParamType<T>>(),
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
}
} }
#else #else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment