Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
2304e2f0
Commit
2304e2f0
authored
Apr 22, 2024
by
yuguo-Jack
Browse files
fix bn bugs
parent
ca9dbdb2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
57 additions
and
136 deletions
+57
-136
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
+20
-39
paddle/phi/kernels/gpu/batch_norm_kernel.cu
paddle/phi/kernels/gpu/batch_norm_kernel.cu
+37
-97
No files found.
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
View file @
2304e2f0
...
...
@@ -589,11 +589,11 @@ void BatchNormGradFunctor(const Context &ctx,
auto
dtype
=
phi
::
backends
::
gpu
::
CudnnDataType
<
T
>::
type
;
#ifdef PADDLE_WITH_HIP
auto
compute_format
=
data_layout
==
DataLayout
::
kNHWC
?
DataLayout
::
kNHWC
:
DataLayout
::
kNCHW
;
data_layout
==
DataLayout
::
kNHWC
?
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
?
DataLayout
::
kNCHW
:
DataLayout
::
kNHWC
)
:
DataLayout
::
kNCHW
;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC
// auto compute_format = DataLayout::kNCHW;
// auto compute_format = DataLayout::kNCHW;
#else
const
bool
fast_nhwc_batch_norm
=
dtype
==
CUDNN_DATA_HALF
&&
FLAGS_cudnn_batchnorm_spatial_persistent
&&
...
...
@@ -762,12 +762,10 @@ void BatchNormGradFunctor(const Context &ctx,
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
transformed_x
.
template
data
<
T
>(),
data_desc_
,
transformed_d_y
.
template
data
<
T
>(),
data_desc_
,
transformed_d_x
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
ctx
.
template
Alloc
<
T
>(
&
transformed_d_x
),
bn_param_desc_
,
new_scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_scale
),
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_bias
),
epsilon
,
saved_mean_data
,
saved_var_data
));
}
else
{
BNBackward
<
T
,
block
,
DataLayout
::
kNCHW
>
...
...
@@ -785,22 +783,6 @@ void BatchNormGradFunctor(const Context &ctx,
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_scale
),
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_bias
));
}
}
else
{
if
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
miopenBatchNormalizationBackward
(
ctx
.
cudnn_handle
(),
mode_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
transformed_x
.
template
data
<
T
>(),
data_desc_
,
transformed_d_y
.
template
data
<
T
>(),
data_desc_
,
transformed_d_x
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
}
else
{
BNBackward
<
T
,
block
,
DataLayout
::
kNHWC
>
<<<
grid2
,
block
,
0
,
ctx
.
stream
()
>>>
(
...
...
@@ -817,7 +799,6 @@ void BatchNormGradFunctor(const Context &ctx,
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_scale
),
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_bias
));
}
}
#else
}
...
...
paddle/phi/kernels/gpu/batch_norm_kernel.cu
View file @
2304e2f0
...
...
@@ -572,7 +572,7 @@ void BatchNormKernel(const Context &ctx,
#ifdef PADDLE_WITH_HIP
auto
compute_format
=
data_layout
==
DataLayout
::
kNHWC
?
DataLayout
::
kNHWC
:
DataLayout
::
kNCHW
;
data_layout
==
DataLayout
::
kNHWC
?
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
?
DataLayout
::
kNCHW
:
DataLayout
::
kNHWC
)
:
DataLayout
::
kNCHW
;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC
...
...
@@ -752,12 +752,12 @@ void BatchNormKernel(const Context &ctx,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
data_desc_
,
static_cast
<
void
*>
(
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()
)),
ctx
.
template
Alloc
<
T
>(
&
transformed_y
)),
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
new_
scale
.
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
new_
bias
.
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
...
...
@@ -777,30 +777,6 @@ void BatchNormKernel(const Context &ctx,
epsilon
,
transformed_y
.
template
data
<
T
>());
}
}
else
{
if
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
miopenBatchNormalizationForwardInference
(
handle
,
mode_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kOne
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kZero
())),
data_desc_
,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
data_desc_
,
static_cast
<
void
*>
(
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
())),
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
())),
epsilon
));
}
else
{
BNForwardInference
<
T
,
DataLayout
::
kNHWC
>
<<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
...
...
@@ -815,7 +791,6 @@ void BatchNormKernel(const Context &ctx,
epsilon
,
transformed_y
.
template
data
<
T
>());
}
}
#else
const
bool
use_native_kernel
=
...
...
@@ -943,24 +918,20 @@ void BatchNormKernel(const Context &ctx,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
data_desc_
,
static_cast
<
void
*>
(
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()
)),
ctx
.
template
Alloc
<
T
>(
&
transformed_y
)),
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
new_
scale
.
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
new_
bias
.
template
data
<
BatchNormParamType
<
T
>
>
())),
this_factor
,
static_cast
<
void
*>
(
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
mean_out
)),
static_cast
<
void
*>
(
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
variance_out
)),
epsilon
,
static_cast
<
void
*>
(
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()))));
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
saved_mean
)),
static_cast
<
void
*>
(
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
saved_variance
))));
}
else
{
BNForwardTraining
<
T
,
block
,
DataLayout
::
kNCHW
>
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
...
...
@@ -978,36 +949,6 @@ void BatchNormKernel(const Context &ctx,
saved_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
saved_variance
->
template
data
<
BatchNormParamType
<
T
>
>
());
}
}
else
{
if
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
miopenBatchNormalizationForwardTraining
(
handle
,
mode_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kOne
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kZero
())),
data_desc_
,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
data_desc_
,
static_cast
<
void
*>
(
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
())),
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
this_factor
,
static_cast
<
void
*>
(
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
epsilon
,
static_cast
<
void
*>
(
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()))));
}
else
{
BNForwardTraining
<
T
,
block
,
DataLayout
::
kNHWC
>
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
...
...
@@ -1025,7 +966,6 @@ void BatchNormKernel(const Context &ctx,
saved_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
saved_variance
->
template
data
<
BatchNormParamType
<
T
>
>
());
}
}
#else
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment