Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
2304e2f0
Commit
2304e2f0
authored
Apr 22, 2024
by
yuguo-Jack
Browse files
fix bn bugs
parent
ca9dbdb2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
57 additions
and
136 deletions
+57
-136
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
+20
-39
paddle/phi/kernels/gpu/batch_norm_kernel.cu
paddle/phi/kernels/gpu/batch_norm_kernel.cu
+37
-97
No files found.
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
View file @
2304e2f0
...
@@ -589,11 +589,11 @@ void BatchNormGradFunctor(const Context &ctx,
...
@@ -589,11 +589,11 @@ void BatchNormGradFunctor(const Context &ctx,
auto
dtype
=
phi
::
backends
::
gpu
::
CudnnDataType
<
T
>::
type
;
auto
dtype
=
phi
::
backends
::
gpu
::
CudnnDataType
<
T
>::
type
;
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
auto
compute_format
=
auto
compute_format
=
data_layout
==
DataLayout
::
kNHWC
?
DataLayout
::
kNHWC
:
DataLayout
::
kNCHW
;
data_layout
==
DataLayout
::
kNHWC
?
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
?
DataLayout
::
kNCHW
:
DataLayout
::
kNHWC
)
:
DataLayout
::
kNCHW
;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC
// HIP do not support compute format of NHWC
// auto compute_format = DataLayout::kNCHW;
// auto compute_format = DataLayout::kNCHW;
#else
#else
const
bool
fast_nhwc_batch_norm
=
dtype
==
CUDNN_DATA_HALF
&&
const
bool
fast_nhwc_batch_norm
=
dtype
==
CUDNN_DATA_HALF
&&
FLAGS_cudnn_batchnorm_spatial_persistent
&&
FLAGS_cudnn_batchnorm_spatial_persistent
&&
...
@@ -762,12 +762,10 @@ void BatchNormGradFunctor(const Context &ctx,
...
@@ -762,12 +762,10 @@ void BatchNormGradFunctor(const Context &ctx,
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
transformed_x
.
template
data
<
T
>(),
data_desc_
,
transformed_x
.
template
data
<
T
>(),
data_desc_
,
transformed_d_y
.
template
data
<
T
>(),
data_desc_
,
transformed_d_y
.
template
data
<
T
>(),
data_desc_
,
transformed_d_x
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
ctx
.
template
Alloc
<
T
>(
&
transformed_d_x
),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
bn_param_desc_
,
new_scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_scale
),
ctx
.
GetPlace
()),
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_bias
),
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
epsilon
,
saved_mean_data
,
saved_var_data
));
}
else
{
}
else
{
BNBackward
<
T
,
block
,
DataLayout
::
kNCHW
>
BNBackward
<
T
,
block
,
DataLayout
::
kNCHW
>
...
@@ -786,37 +784,20 @@ void BatchNormGradFunctor(const Context &ctx,
...
@@ -786,37 +784,20 @@ void BatchNormGradFunctor(const Context &ctx,
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_bias
));
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_bias
));
}
}
}
else
{
}
else
{
if
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
)
{
BNBackward
<
T
,
block
,
DataLayout
::
kNHWC
>
PADDLE_ENFORCE_GPU_SUCCESS
(
<<<
grid2
,
block
,
0
,
ctx
.
stream
()
>>>
(
phi
::
dynload
::
miopenBatchNormalizationBackward
(
transformed_d_y
.
template
data
<
T
>(),
ctx
.
cudnn_handle
(),
mode_
,
CudnnDataType
<
T
>::
kOne
(),
transformed_x
.
template
data
<
T
>(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
new_scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
saved_mean_data
,
transformed_x
.
template
data
<
T
>(),
data_desc_
,
saved_var_data
,
transformed_d_y
.
template
data
<
T
>(),
data_desc_
,
C
,
transformed_d_x
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
N
,
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
H
*
W
*
D
,
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
epsilon
,
ctx
.
GetPlace
()),
transformed_d_x
.
template
data
<
T
>(),
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_scale
),
ctx
.
GetPlace
()),
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_bias
));
epsilon
,
saved_mean_data
,
saved_var_data
));
}
else
{
BNBackward
<
T
,
block
,
DataLayout
::
kNHWC
>
<<<
grid2
,
block
,
0
,
ctx
.
stream
()
>>>
(
transformed_d_y
.
template
data
<
T
>(),
transformed_x
.
template
data
<
T
>(),
new_scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
saved_mean_data
,
saved_var_data
,
C
,
N
,
H
*
W
*
D
,
epsilon
,
transformed_d_x
.
template
data
<
T
>(),
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_scale
),
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
d_bias
));
}
}
}
#else
#else
...
...
paddle/phi/kernels/gpu/batch_norm_kernel.cu
View file @
2304e2f0
...
@@ -572,7 +572,7 @@ void BatchNormKernel(const Context &ctx,
...
@@ -572,7 +572,7 @@ void BatchNormKernel(const Context &ctx,
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
auto
compute_format
=
auto
compute_format
=
data_layout
==
DataLayout
::
kNHWC
?
DataLayout
::
kNHWC
:
DataLayout
::
kNCHW
;
data_layout
==
DataLayout
::
kNHWC
?
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
?
DataLayout
::
kNCHW
:
DataLayout
::
kNHWC
)
:
DataLayout
::
kNCHW
;
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// HIP do not support compute format of NHWC
// HIP do not support compute format of NHWC
...
@@ -752,12 +752,12 @@ void BatchNormKernel(const Context &ctx,
...
@@ -752,12 +752,12 @@ void BatchNormKernel(const Context &ctx,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
data_desc_
,
data_desc_
,
static_cast
<
void
*>
(
static_cast
<
void
*>
(
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()
)),
ctx
.
template
Alloc
<
T
>(
&
transformed_y
)),
bn_param_desc_
,
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
new_
scale
.
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
new_
bias
.
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
())),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
...
@@ -778,43 +778,18 @@ void BatchNormKernel(const Context &ctx,
...
@@ -778,43 +778,18 @@ void BatchNormKernel(const Context &ctx,
transformed_y
.
template
data
<
T
>());
transformed_y
.
template
data
<
T
>());
}
}
}
else
{
}
else
{
if
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
)
{
BNForwardInference
<
T
,
DataLayout
::
kNHWC
>
PADDLE_ENFORCE_GPU_SUCCESS
(
<<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
phi
::
dynload
::
miopenBatchNormalizationForwardInference
(
transformed_x
.
template
data
<
T
>(),
handle
,
mode_
,
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
const_cast
<
void
*>
(
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
(),
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kOne
())),
new_scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
const_cast
<
void
*>
(
new_bias
.
template
data
<
BatchNormParamType
<
T
>
>
(),
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kZero
())),
C
,
data_desc_
,
N
,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
H
*
W
*
D
,
data_desc_
,
epsilon
,
static_cast
<
void
*>
(
transformed_y
.
template
data
<
T
>());
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
())),
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
())),
epsilon
));
}
else
{
BNForwardInference
<
T
,
DataLayout
::
kNHWC
>
<<<
grid_size
,
block_size
,
0
,
ctx
.
stream
()
>>>
(
transformed_x
.
template
data
<
T
>(),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_var
->
template
data
<
BatchNormParamType
<
T
>
>
(),
new_scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
new_bias
.
template
data
<
BatchNormParamType
<
T
>
>
(),
C
,
N
,
H
*
W
*
D
,
epsilon
,
transformed_y
.
template
data
<
T
>());
}
}
}
#else
#else
...
@@ -943,24 +918,20 @@ void BatchNormKernel(const Context &ctx,
...
@@ -943,24 +918,20 @@ void BatchNormKernel(const Context &ctx,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
data_desc_
,
data_desc_
,
static_cast
<
void
*>
(
static_cast
<
void
*>
(
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()
)),
ctx
.
template
Alloc
<
T
>(
&
transformed_y
)),
bn_param_desc_
,
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
new_
scale
.
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
new_
bias
.
template
data
<
BatchNormParamType
<
T
>
>
())),
this_factor
,
this_factor
,
static_cast
<
void
*>
(
static_cast
<
void
*>
(
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
mean_out
)),
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
variance_out
)),
static_cast
<
void
*>
(
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
epsilon
,
epsilon
,
static_cast
<
void
*>
(
static_cast
<
void
*>
(
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
saved_mean
)),
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
ctx
.
template
Alloc
<
BatchNormParamType
<
T
>
>
(
saved_variance
))));
static_cast
<
void
*>
(
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()))));
}
else
{
}
else
{
BNForwardTraining
<
T
,
block
,
DataLayout
::
kNCHW
>
BNForwardTraining
<
T
,
block
,
DataLayout
::
kNCHW
>
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
...
@@ -979,52 +950,21 @@ void BatchNormKernel(const Context &ctx,
...
@@ -979,52 +950,21 @@ void BatchNormKernel(const Context &ctx,
saved_variance
->
template
data
<
BatchNormParamType
<
T
>
>
());
saved_variance
->
template
data
<
BatchNormParamType
<
T
>
>
());
}
}
}
else
{
}
else
{
if
(
FLAGS_cudnn_batchnorm_spatial_persistent
==
true
)
{
BNForwardTraining
<
T
,
block
,
DataLayout
::
kNHWC
>
PADDLE_ENFORCE_GPU_SUCCESS
(
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
phi
::
dynload
::
miopenBatchNormalizationForwardTraining
(
transformed_x
.
template
data
<
T
>(),
handle
,
mode_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
new_scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
CudnnDataType
<
T
>::
kOne
())),
new_bias
.
template
data
<
BatchNormParamType
<
T
>
>
(),
const_cast
<
void
*>
(
C
,
static_cast
<
const
void
*>
(
CudnnDataType
<
T
>::
kZero
())),
N
,
data_desc_
,
H
*
W
*
D
,
static_cast
<
const
void
*>
(
transformed_x
.
template
data
<
T
>()),
data_desc_
,
static_cast
<
void
*>
(
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
())),
bn_param_desc_
,
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale
->
template
data
<
BatchNormParamType
<
T
>
>
())),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias
->
template
data
<
BatchNormParamType
<
T
>
>
())),
this_factor
,
static_cast
<
void
*>
(
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
static_cast
<
void
*>
(
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())),
epsilon
,
epsilon
,
static_cast
<
void
*>
(
this_factor
,
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
transformed_y
.
template
data
<
T
>(),
ctx
.
GetPlace
())),
mean_out
->
template
data
<
BatchNormParamType
<
T
>
>
(),
static_cast
<
void
*>
(
saved_variance
->
template
mutable_data
<
variance_out
->
template
data
<
BatchNormParamType
<
T
>
>
(),
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()))));
saved_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
}
else
{
saved_variance
->
template
data
<
BatchNormParamType
<
T
>
>
());
BNForwardTraining
<
T
,
block
,
DataLayout
::
kNHWC
>
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
transformed_x
.
template
data
<
T
>(),
new_scale
.
template
data
<
BatchNormParamType
<
T
>
>
(),
new_bias
.
template
data
<
BatchNormParamType
<
T
>
>
(),
C
,
N
,
H
*
W
*
D
,
epsilon
,
this_factor
,
transformed_y
.
template
data
<
T
>(),
mean_out
->
template
data
<
BatchNormParamType
<
T
>
>
(),
variance_out
->
template
data
<
BatchNormParamType
<
T
>
>
(),
saved_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
saved_variance
->
template
data
<
BatchNormParamType
<
T
>
>
());
}
}
}
#else
#else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment