Commit 2304e2f0 authored by yuguo-Jack

fix bn bugs

parent ca9dbdb2
@@ -589,11 +589,11 @@ void BatchNormGradFunctor(const Context &ctx,
auto dtype = phi::backends::gpu::CudnnDataType<T>::type;
#ifdef PADDLE_WITH_HIP
auto compute_format =
data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;
data_layout == DataLayout::kNHWC ? (FLAGS_cudnn_batchnorm_spatial_persistent == true ? DataLayout::kNCHW : DataLayout::kNHWC) : DataLayout::kNCHW;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP does not support compute format of NHWC
// auto compute_format = DataLayout::kNCHW;
#else
const bool fast_nhwc_batch_norm = dtype == CUDNN_DATA_HALF &&
FLAGS_cudnn_batchnorm_spatial_persistent &&
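A minimal sketch of the layout selection this commit introduces on the HIP path (ChooseComputeFormat is an illustrative name, not part of the patch; DataLayout and FLAGS_cudnn_batchnorm_spatial_persistent come from the surrounding code): when the flag is enabled, NHWC input is computed in NCHW so the MIOpen routines, which do not support NHWC, can be used; with the flag off, the hand-written NHWC kernels operate on the data directly.

// Illustrative helper mirroring the new ternary above (not part of the patch).
inline DataLayout ChooseComputeFormat(DataLayout data_layout,
                                      bool spatial_persistent) {
  if (data_layout != DataLayout::kNHWC) {
    return DataLayout::kNCHW;  // non-NHWC input already computes in NCHW
  }
  // NHWC input: MIOpen only handles NCHW, so the flag routes through NCHW.
  return spatial_persistent ? DataLayout::kNCHW : DataLayout::kNHWC;
}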
@@ -762,12 +762,10 @@ void BatchNormGradFunctor(const Context &ctx,
CudnnDataType<T>::kZero(), data_desc_,
transformed_x.template data<T>(), data_desc_,
transformed_d_y.template data<T>(), data_desc_,
transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
ctx.template Alloc<T>(&transformed_d_x),
bn_param_desc_, new_scale.template data<BatchNormParamType<T>>(),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias),
epsilon, saved_mean_data, saved_var_data));
} else {
BNBackward<T, block, DataLayout::kNCHW>
@@ -786,37 +784,20 @@ void BatchNormGradFunctor(const Context &ctx,
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
}
} else {
if (FLAGS_cudnn_batchnorm_spatial_persistent == true) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::miopenBatchNormalizationBackward(
ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_,
transformed_x.template data<T>(), data_desc_,
transformed_d_y.template data<T>(), data_desc_,
transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
d_scale->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
d_bias->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean_data, saved_var_data));
} else {
BNBackward<T, block, DataLayout::kNHWC>
<<<grid2, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
new_scale.template data<BatchNormParamType<T>>(),
saved_mean_data,
saved_var_data,
C,
N,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
}
BNBackward<T, block, DataLayout::kNHWC>
<<<grid2, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
new_scale.template data<BatchNormParamType<T>>(),
saved_mean_data,
saved_var_data,
C,
N,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
}
#else
......
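These backward hunks also move output allocation from the legacy mutable_data call onto the phi device-context allocator used elsewhere in the file. A minimal sketch of the idiom, with AllocOutput as a hypothetical helper name:

// Hypothetical helper illustrating the allocation idiom used after this commit.
template <typename T, typename Context>
T *AllocOutput(const Context &ctx, phi::DenseTensor *out) {
  // Replaces out->template mutable_data<T>(ctx.GetPlace()).
  return ctx.template Alloc<T>(out);
}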
@@ -572,7 +572,7 @@ void BatchNormKernel(const Context &ctx,
#ifdef PADDLE_WITH_HIP
auto compute_format =
data_layout == DataLayout::kNHWC ? DataLayout::kNHWC : DataLayout::kNCHW;
data_layout == DataLayout::kNHWC ? (FLAGS_cudnn_batchnorm_spatial_persistent == true ? DataLayout::kNCHW : DataLayout::kNHWC) : DataLayout::kNCHW;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP does not support compute format of NHWC
@@ -752,12 +752,12 @@ void BatchNormKernel(const Context &ctx,
static_cast<const void *>(transformed_x.template data<T>()),
data_desc_,
static_cast<void *>(
transformed_y.template mutable_data<T>(ctx.GetPlace())),
ctx.template Alloc<T>(&transformed_y)),
bn_param_desc_,
const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())),
new_scale.template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())),
new_bias.template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
est_mean->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
@@ -778,43 +778,18 @@ void BatchNormKernel(const Context &ctx,
transformed_y.template data<T>());
}
} else {
if (FLAGS_cudnn_batchnorm_spatial_persistent == true) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::miopenBatchNormalizationForwardInference(
handle, mode_,
const_cast<void *>(
static_cast<const void *>(CudnnDataType<T>::kOne())),
const_cast<void *>(
static_cast<const void *>(CudnnDataType<T>::kZero())),
data_desc_,
static_cast<const void *>(transformed_x.template data<T>()),
data_desc_,
static_cast<void *>(
transformed_y.template mutable_data<T>(ctx.GetPlace())),
bn_param_desc_,
const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
est_mean->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
est_var->template data<BatchNormParamType<T>>())),
epsilon));
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
new_scale.template data<BatchNormParamType<T>>(),
new_bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
}
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
new_scale.template data<BatchNormParamType<T>>(),
new_bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
}
#else
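With the MIOpen fallback removed from the NHWC inference branch above, that path now always runs the BNForwardInference kernel, which applies the standard inference-time normalization per channel using the running statistics. A scalar sketch of that arithmetic, with illustrative names:

#include <cmath>

// Scalar sketch of inference-time batch norm for one element of channel c
// (illustrative function; mean_c/var_c are the running statistics).
float BNInferOne(float x, float mean_c, float var_c,
                 float scale_c, float bias_c, float epsilon) {
  float inv_std = 1.0f / std::sqrt(var_c + epsilon);
  return scale_c * (x - mean_c) * inv_std + bias_c;
}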
@@ -943,24 +918,20 @@ void BatchNormKernel(const Context &ctx,
static_cast<const void *>(transformed_x.template data<T>()),
data_desc_,
static_cast<void *>(
transformed_y.template mutable_data<T>(ctx.GetPlace())),
ctx.template Alloc<T>(&transformed_y)),
bn_param_desc_,
const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())),
new_scale.template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())),
new_bias.template data<BatchNormParamType<T>>())),
this_factor,
static_cast<void *>(
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())),
static_cast<void *>(variance_out->template mutable_data<
BatchNormParamType<T>>(ctx.GetPlace())),
ctx.template Alloc<BatchNormParamType<T>>(mean_out)),
static_cast<void *>(ctx.template Alloc<BatchNormParamType<T>>(variance_out)),
epsilon,
static_cast<void *>(
saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())),
static_cast<void *>(saved_variance->template mutable_data<
BatchNormParamType<T>>(ctx.GetPlace()))));
ctx.template Alloc<BatchNormParamType<T>>(saved_mean)),
static_cast<void *>(ctx.template Alloc<BatchNormParamType<T>>(saved_variance))));
} else {
BNForwardTraining<T, block, DataLayout::kNCHW>
<<<grid, block, 0, ctx.stream()>>>(
@@ -979,52 +950,21 @@ void BatchNormKernel(const Context &ctx,
saved_variance->template data<BatchNormParamType<T>>());
}
} else {
if (FLAGS_cudnn_batchnorm_spatial_persistent == true) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::miopenBatchNormalizationForwardTraining(
handle, mode_, const_cast<void *>(static_cast<const void *>(
CudnnDataType<T>::kOne())),
const_cast<void *>(
static_cast<const void *>(CudnnDataType<T>::kZero())),
data_desc_,
static_cast<const void *>(transformed_x.template data<T>()),
data_desc_,
static_cast<void *>(
transformed_y.template mutable_data<T>(ctx.GetPlace())),
bn_param_desc_,
const_cast<void *>(static_cast<const void *>(
scale->template data<BatchNormParamType<T>>())),
const_cast<void *>(static_cast<const void *>(
bias->template data<BatchNormParamType<T>>())),
this_factor,
static_cast<void *>(
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())),
static_cast<void *>(variance_out->template mutable_data<
BatchNormParamType<T>>(ctx.GetPlace())),
BNForwardTraining<T, block, DataLayout::kNHWC>
<<<grid, block, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
new_scale.template data<BatchNormParamType<T>>(),
new_bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
static_cast<void *>(
saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())),
static_cast<void *>(saved_variance->template mutable_data<
BatchNormParamType<T>>(ctx.GetPlace()))));
} else {
BNForwardTraining<T, block, DataLayout::kNHWC>
<<<grid, block, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
new_scale.template data<BatchNormParamType<T>>(),
new_bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
this_factor,
transformed_y.template data<T>(),
mean_out->template data<BatchNormParamType<T>>(),
variance_out->template data<BatchNormParamType<T>>(),
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
}
this_factor,
transformed_y.template data<T>(),
mean_out->template data<BatchNormParamType<T>>(),
variance_out->template data<BatchNormParamType<T>>(),
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
}
#else
......
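The training hunk follows the same pattern: the MIOpen branch is dropped and BNForwardTraining handles NHWC directly, writing the batch statistics and updating the running mean and variance with this_factor. A scalar sketch of that update, assuming this_factor acts as the usual exponential-average factor:

// Scalar sketch of the running-statistics update performed during training,
// assuming this_factor is the exponential-average factor (illustrative names).
void UpdateRunningStat(float *running, float batch_stat, float this_factor) {
  *running = (1.0f - this_factor) * (*running) + this_factor * batch_stat;
}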