Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
59613285
Commit
59613285
authored
Nov 03, 2022
by
Qianfeng Zhang
Browse files
Add dy_elementwise_op
parent
f4d67cf8
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
69 additions
and
7 deletions
+69
-7
example/34_batchnorm/batchnorm_backward_nhwc.cpp
example/34_batchnorm/batchnorm_backward_nhwc.cpp
+7
-1
include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp
...tensor_operation/gpu/device/device_batchnorm_backward.hpp
+5
-3
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
...ration/gpu/device/impl/device_batchnorm_backward_impl.hpp
+17
-1
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
...ultiblock_reduce_second_half_batchnorm_backward_final.hpp
+8
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp
...lock_welford_second_half_multiblock_reduce_first_half.hpp
+8
-0
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp
...pu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp
+11
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp
...sor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp
+13
-2
No files found.
example/34_batchnorm/batchnorm_backward_nhwc.cpp
View file @
59613285
...
@@ -253,6 +253,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -253,6 +253,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
scaleBiasMeanVarStrides
.
end
(),
scaleBiasMeanVarStrides
.
end
(),
i_scaleBiasMeanVarStrides
.
begin
());
i_scaleBiasMeanVarStrides
.
begin
());
using
PassThroughOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceBatchNormBwdInstance
=
using
DeviceBatchNormBwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchNormBwdImpl
<
InOutDataType
,
ck
::
tensor_operation
::
device
::
DeviceBatchNormBwdImpl
<
InOutDataType
,
InOutDataType
,
InOutDataType
,
...
@@ -261,6 +263,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -261,6 +263,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
AccDataType
,
// ScaleDataType
AccDataType
,
// ScaleDataType
AccDataType
,
// BiasDataType
AccDataType
,
// BiasDataType
AccDataType
,
// MeanVarDataType
AccDataType
,
// MeanVarDataType
PassThroughOp
,
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
UseMultiblockInK
,
UseMultiblockInK
,
...
@@ -295,6 +298,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -295,6 +298,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
haveSavedMeanInvVar
?
savedMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
haveSavedMeanInvVar
?
savedMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
haveSavedMeanInvVar
?
savedInvVar_dev
.
GetDeviceBuffer
()
:
nullptr
,
haveSavedMeanInvVar
?
savedInvVar_dev
.
GetDeviceBuffer
()
:
nullptr
,
epsilon
,
epsilon
,
PassThroughOp
{},
dx_dev
.
GetDeviceBuffer
(),
dx_dev
.
GetDeviceBuffer
(),
bnScaleDiff_dev
.
GetDeviceBuffer
(),
bnScaleDiff_dev
.
GetDeviceBuffer
(),
bnBiasDiff_dev
.
GetDeviceBuffer
());
bnBiasDiff_dev
.
GetDeviceBuffer
());
...
@@ -350,7 +354,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -350,7 +354,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
>
;
AccDataType
,
PassThroughOp
>
;
auto
batchNormBwd_ref
=
ReferenceBatchNormBwdInstance
{};
auto
batchNormBwd_ref
=
ReferenceBatchNormBwdInstance
{};
...
@@ -370,6 +375,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -370,6 +375,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
haveSavedMeanInvVar
?
savedMean
.
mData
.
data
()
:
nullptr
,
haveSavedMeanInvVar
?
savedMean
.
mData
.
data
()
:
nullptr
,
haveSavedMeanInvVar
?
savedInvVar
.
mData
.
data
()
:
nullptr
,
haveSavedMeanInvVar
?
savedInvVar
.
mData
.
data
()
:
nullptr
,
epsilon
,
epsilon
,
PassThroughOp
{},
dx_ref
.
mData
.
data
(),
dx_ref
.
mData
.
data
(),
bnScaleDiff_ref
.
mData
.
data
(),
bnScaleDiff_ref
.
mData
.
data
(),
bnBiasDiff_ref
.
mData
.
data
());
bnBiasDiff_ref
.
mData
.
data
());
...
...
include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp
View file @
59613285
...
@@ -13,7 +13,7 @@ namespace ck {
...
@@ -13,7 +13,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
typename
DyElementwiseOp
>
struct
DeviceBatchNormBwd
:
public
BaseOperator
struct
DeviceBatchNormBwd
:
public
BaseOperator
{
{
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
...
@@ -34,6 +34,7 @@ struct DeviceBatchNormBwd : public BaseOperator
...
@@ -34,6 +34,7 @@ struct DeviceBatchNormBwd : public BaseOperator
const
void
*
p_savedMean
,
const
void
*
p_savedMean
,
const
void
*
p_savedInvVar
,
const
void
*
p_savedInvVar
,
double
epsilon
,
double
epsilon
,
const
DyElementwiseOp
dy_elementwise_op
,
void
*
p_dx
,
void
*
p_dx
,
void
*
p_dscale
,
void
*
p_dscale
,
void
*
p_dbias
)
=
0
;
void
*
p_dbias
)
=
0
;
...
@@ -41,8 +42,9 @@ struct DeviceBatchNormBwd : public BaseOperator
...
@@ -41,8 +42,9 @@ struct DeviceBatchNormBwd : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
typename
DyElementwiseOp
>
using
DeviceBatchNormBwdPtr
=
std
::
unique_ptr
<
DeviceBatchNormBwd
<
Rank
,
NumBatchNormReduceDim
>>
;
using
DeviceBatchNormBwdPtr
=
std
::
unique_ptr
<
DeviceBatchNormBwd
<
Rank
,
NumBatchNormReduceDim
,
DyElementwiseOp
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
View file @
59613285
...
@@ -29,6 +29,7 @@ template <typename XDataType,
...
@@ -29,6 +29,7 @@ template <typename XDataType,
typename
ScaleDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
index_t
Rank
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
index_t
NumBatchNormReduceDim
,
bool
UseMultiblockInK
,
bool
UseMultiblockInK
,
...
@@ -44,7 +45,8 @@ template <typename XDataType,
...
@@ -44,7 +45,8 @@ template <typename XDataType,
index_t
ScaleSrcDstVectorSize
,
index_t
ScaleSrcDstVectorSize
,
index_t
BiasDstVectorSize
,
index_t
BiasDstVectorSize
,
index_t
MeanVarSrcVectorSize
>
index_t
MeanVarSrcVectorSize
>
struct
DeviceBatchNormBwdImpl
:
public
DeviceBatchNormBwd
<
Rank
,
NumBatchNormReduceDim
>
struct
DeviceBatchNormBwdImpl
:
public
DeviceBatchNormBwd
<
Rank
,
NumBatchNormReduceDim
,
DyElementwiseOp
>
{
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
...
@@ -199,6 +201,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -199,6 +201,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const
ScaleDataType
*
p_scale
,
const
ScaleDataType
*
p_scale
,
const
MeanVarDataType
*
p_savedMean
,
const
MeanVarDataType
*
p_savedMean
,
const
MeanVarDataType
*
p_savedInvVar
,
const
MeanVarDataType
*
p_savedInvVar
,
const
DyElementwiseOp
dy_elementwise_op
,
double
epsilon
,
double
epsilon
,
DxDataType
*
p_dx
,
DxDataType
*
p_dx
,
ScaleDataType
*
p_dscale
,
ScaleDataType
*
p_dscale
,
...
@@ -212,6 +215,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -212,6 +215,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
p_scale_
(
p_scale
),
p_scale_
(
p_scale
),
p_savedMean_
(
p_savedMean
),
p_savedMean_
(
p_savedMean
),
p_savedInvVar_
(
p_savedInvVar
),
p_savedInvVar_
(
p_savedInvVar
),
dy_elementwise_op_
(
dy_elementwise_op
),
p_dx_
(
p_dx
),
p_dx_
(
p_dx
),
p_dscale_
(
p_dscale
),
p_dscale_
(
p_dscale
),
p_dbias_
(
p_dbias
)
p_dbias_
(
p_dbias
)
...
@@ -293,6 +297,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -293,6 +297,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const
ScaleDataType
*
p_scale_
;
const
ScaleDataType
*
p_scale_
;
const
MeanVarDataType
*
p_savedMean_
;
const
MeanVarDataType
*
p_savedMean_
;
const
MeanVarDataType
*
p_savedInvVar_
;
const
MeanVarDataType
*
p_savedInvVar_
;
const
DyElementwiseOp
dy_elementwise_op_
;
DxDataType
*
p_dx_
;
DxDataType
*
p_dx_
;
ScaleDataType
*
p_dscale_
;
ScaleDataType
*
p_dscale_
;
BiasDataType
*
p_dbias_
;
BiasDataType
*
p_dbias_
;
...
@@ -451,6 +456,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -451,6 +456,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType
,
ScaleDataType
,
BiasDataType
,
BiasDataType
,
MeanVarDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
XYGridDesc_M_K
,
MeanVarGridDesc_M
,
MeanVarGridDesc_M
,
MeanVarCountGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
...
@@ -473,6 +479,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -473,6 +479,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType
,
ScaleDataType
,
BiasDataType
,
BiasDataType
,
MeanVarDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
XYGridDesc_M_K
,
DscaleDbiasGridDesc_M_K
,
DscaleDbiasGridDesc_M_K
,
MeanVarGridDesc_M
,
MeanVarGridDesc_M
,
...
@@ -548,6 +555,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -548,6 +555,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType
,
ScaleDataType
,
BiasDataType
,
BiasDataType
,
MeanVarDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
XYGridDesc_M_K
,
MeanVarGridDesc_M
,
MeanVarGridDesc_M
,
MeanVarCountGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
...
@@ -562,6 +570,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -562,6 +570,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType
,
ScaleDataType
,
BiasDataType
,
BiasDataType
,
MeanVarDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
XYGridDesc_M_K
,
DscaleDbiasGridDesc_M_K
,
DscaleDbiasGridDesc_M_K
,
MeanVarGridDesc_M
,
MeanVarGridDesc_M
,
...
@@ -596,6 +605,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -596,6 +605,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_variance
),
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_variance
),
arg
.
haveSavedMeanInvVar_
?
nullptr
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
const
int32_t
*>
(
arg
.
workspace_count
),
:
static_cast
<
const
int32_t
*>
(
arg
.
workspace_count
),
arg
.
dy_elementwise_op_
,
arg
.
haveSavedMeanInvVar_
arg
.
haveSavedMeanInvVar_
?
nullptr
?
nullptr
:
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_savedMean
),
:
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_savedMean
),
...
@@ -635,6 +645,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -635,6 +645,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg
.
p_x_
,
arg
.
p_x_
,
arg
.
p_dy_
,
arg
.
p_dy_
,
arg
.
p_scale_
,
arg
.
p_scale_
,
arg
.
dy_elementwise_op_
,
arg
.
p_dx_
,
arg
.
p_dx_
,
arg
.
p_dscale_
,
arg
.
p_dscale_
,
arg
.
p_dbias_
);
arg
.
p_dbias_
);
...
@@ -655,6 +666,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -655,6 +666,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType
,
ScaleDataType
,
BiasDataType
,
BiasDataType
,
MeanVarDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
XYGridDesc_M_K
,
ScaleBiasGridDesc_M
,
ScaleBiasGridDesc_M
,
MeanVarGridDesc_M
,
MeanVarGridDesc_M
,
...
@@ -681,6 +693,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -681,6 +693,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
ScaleDataType
,
ScaleDataType
,
BiasDataType
,
BiasDataType
,
MeanVarDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
XYGridDesc_M_K
,
ScaleBiasGridDesc_M
,
ScaleBiasGridDesc_M
,
MeanVarGridDesc_M
,
MeanVarGridDesc_M
,
...
@@ -707,6 +720,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -707,6 +720,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
arg
.
haveSavedMeanInvVar_
,
arg
.
haveSavedMeanInvVar_
,
arg
.
p_savedMean_
,
arg
.
p_savedMean_
,
arg
.
p_savedInvVar_
,
arg
.
p_savedInvVar_
,
arg
.
dy_elementwise_op_
,
arg
.
p_dx_
,
arg
.
p_dx_
,
arg
.
p_dscale_
,
arg
.
p_dscale_
,
arg
.
p_dbias_
);
arg
.
p_dbias_
);
...
@@ -800,6 +814,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -800,6 +814,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
const
void
*
p_savedMean
,
const
void
*
p_savedMean
,
const
void
*
p_savedInvVar
,
const
void
*
p_savedInvVar
,
double
epsilon
,
double
epsilon
,
const
DyElementwiseOp
dy_elementwise_op
,
void
*
p_dx
,
void
*
p_dx
,
void
*
p_dscale
,
void
*
p_dscale
,
void
*
p_dbias
)
override
void
*
p_dbias
)
override
...
@@ -818,6 +833,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
...
@@ -818,6 +833,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
static_cast
<
const
ScaleDataType
*>
(
p_scale
),
static_cast
<
const
ScaleDataType
*>
(
p_scale
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedMean
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedMean
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedInvVar
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedInvVar
),
dy_elementwise_op
,
epsilon
,
epsilon
,
static_cast
<
DxDataType
*>
(
p_dx
),
static_cast
<
DxDataType
*>
(
p_dx
),
static_cast
<
ScaleDataType
*>
(
p_dscale
),
static_cast
<
ScaleDataType
*>
(
p_dscale
),
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
View file @
59613285
...
@@ -18,6 +18,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
...
@@ -18,6 +18,7 @@ template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
typename
ScaleDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
XYGridDesc_M_K
,
typename
DscaleDbiasGridDesc_M_K
,
typename
DscaleDbiasGridDesc_M_K
,
typename
MeanVarGridDesc_M
,
typename
MeanVarGridDesc_M
,
...
@@ -41,6 +42,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
...
@@ -41,6 +42,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
const
XDataType
*
const
__restrict__
p_x
,
const
XDataType
*
const
__restrict__
p_x
,
const
DyDataType
*
const
__restrict__
p_dy
,
const
DyDataType
*
const
__restrict__
p_dy
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
const
__restrict__
p_dx
,
DxDataType
*
const
__restrict__
p_dx
,
ScaleDataType
*
const
__restrict__
p_dscale
,
ScaleDataType
*
const
__restrict__
p_dscale
,
BiasDataType
*
const
__restrict__
p_dbias
)
BiasDataType
*
const
__restrict__
p_dbias
)
...
@@ -63,6 +65,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
...
@@ -63,6 +65,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
p_x
,
p_x
,
p_dy
,
p_dy
,
p_scale
,
p_scale
,
dy_elementwise_op
,
p_dx
,
p_dx
,
p_dscale
,
p_dscale
,
p_dbias
);
p_dbias
);
...
@@ -75,6 +78,7 @@ template <typename XDataType,
...
@@ -75,6 +78,7 @@ template <typename XDataType,
typename
ScaleDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
XYGridDesc_M_K
,
typename
DscaleDbiasGridDesc_M_K
,
typename
DscaleDbiasGridDesc_M_K
,
typename
MeanVarGridDesc_M
,
typename
MeanVarGridDesc_M
,
...
@@ -163,6 +167,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
...
@@ -163,6 +167,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
const
XDataType
*
const
__restrict__
p_x
,
const
XDataType
*
const
__restrict__
p_x
,
const
DyDataType
*
const
__restrict__
p_dy
,
const
DyDataType
*
const
__restrict__
p_dy
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
const
__restrict__
p_dx
,
DxDataType
*
const
__restrict__
p_dx
,
ScaleDataType
*
const
__restrict__
p_dscale
,
ScaleDataType
*
const
__restrict__
p_dscale
,
BiasDataType
*
const
__restrict__
p_dbias
)
BiasDataType
*
const
__restrict__
p_dbias
)
...
@@ -498,6 +503,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
...
@@ -498,6 +503,9 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
constexpr
auto
offset
=
constexpr
auto
offset
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
dy_elementwise_op
(
dy_thread_buf
(
Number
<
offset
>
{}),
dy_thread_buf
[
Number
<
offset
>
{}]);
AccDataType
norm_x
=
(
x_thread_buf
[
Number
<
offset
>
{}]
-
mean_thread_buf
[
iM
])
*
AccDataType
norm_x
=
(
x_thread_buf
[
Number
<
offset
>
{}]
-
mean_thread_buf
[
iM
])
*
inv_var_thread_buf
[
iM
];
inv_var_thread_buf
[
iM
];
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp
View file @
59613285
...
@@ -19,6 +19,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
...
@@ -19,6 +19,7 @@ template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
typename
ScaleDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
XYGridDesc_M_K
,
typename
MeanVarGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
MeanVarCountGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_K
,
...
@@ -39,6 +40,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
...
@@ -39,6 +40,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
const
int32_t
*
const
__restrict__
p_in_welford_count
,
const
int32_t
*
const
__restrict__
p_in_welford_count
,
const
DyElementwiseOp
dy_elementwise_op
,
MeanVarDataType
*
const
__restrict__
p_out_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_out_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_out_welford_inv_variance
,
MeanVarDataType
*
const
__restrict__
p_out_welford_inv_variance
,
const
XDataType
*
const
__restrict__
p_x
,
const
XDataType
*
const
__restrict__
p_x
,
...
@@ -61,6 +63,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
...
@@ -61,6 +63,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
p_in_welford_mean
,
p_in_welford_mean
,
p_in_welford_variance
,
p_in_welford_variance
,
p_in_welford_count
,
p_in_welford_count
,
dy_elementwise_op
,
p_out_welford_mean
,
p_out_welford_mean
,
p_out_welford_inv_variance
,
p_out_welford_inv_variance
,
p_x
,
p_x
,
...
@@ -75,6 +78,7 @@ template <typename XDataType,
...
@@ -75,6 +78,7 @@ template <typename XDataType,
typename
ScaleDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
XYGridDesc_M_K
,
typename
MeanVarGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
MeanVarCountGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_K
,
...
@@ -165,6 +169,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
...
@@ -165,6 +169,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
const
int32_t
*
const
__restrict__
p_in_welford_count
,
const
int32_t
*
const
__restrict__
p_in_welford_count
,
const
DyElementwiseOp
dy_elementwise_op
,
MeanVarDataType
*
const
__restrict__
p_out_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_out_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_out_welford_inv_variance
,
MeanVarDataType
*
const
__restrict__
p_out_welford_inv_variance
,
const
XDataType
*
const
__restrict__
p_x
,
const
XDataType
*
const
__restrict__
p_x
,
...
@@ -480,6 +485,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
...
@@ -480,6 +485,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
constexpr
auto
offset
=
constexpr
auto
offset
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
dy_elementwise_op
(
dy_thread_buf
(
Number
<
offset
>
{}),
dy_thread_buf
[
Number
<
offset
>
{}]);
AccDataType
norm_x
=
(
x_thread_buf
[
Number
<
offset
>
{}]
-
mean_thread_buf
[
iM
])
*
AccDataType
norm_x
=
(
x_thread_buf
[
Number
<
offset
>
{}]
-
mean_thread_buf
[
iM
])
*
inv_var_thread_buf
[
iM
];
inv_var_thread_buf
[
iM
];
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp
View file @
59613285
...
@@ -23,6 +23,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
...
@@ -23,6 +23,7 @@ template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
typename
ScaleDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
XYGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
MeanVarGridDesc_M
,
...
@@ -44,6 +45,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
...
@@ -44,6 +45,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
bool
haveSavedMeanInvVar
,
bool
haveSavedMeanInvVar
,
const
MeanVarDataType
*
const
__restrict__
p_savedMean
,
const
MeanVarDataType
*
const
__restrict__
p_savedMean
,
const
MeanVarDataType
*
const
__restrict__
p_savedInvVar
,
const
MeanVarDataType
*
const
__restrict__
p_savedInvVar
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
const
__restrict__
p_dx
,
DxDataType
*
const
__restrict__
p_dx
,
ScaleDataType
*
const
__restrict__
p_dscale
,
ScaleDataType
*
const
__restrict__
p_dscale
,
BiasDataType
*
const
__restrict__
p_dbias
)
BiasDataType
*
const
__restrict__
p_dbias
)
...
@@ -64,6 +66,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
...
@@ -64,6 +66,7 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
haveSavedMeanInvVar
,
haveSavedMeanInvVar
,
p_savedMean
,
p_savedMean
,
p_savedInvVar
,
p_savedInvVar
,
dy_elementwise_op
,
p_dx
,
p_dx
,
p_dscale
,
p_dscale
,
p_dbias
);
p_dbias
);
...
@@ -76,6 +79,7 @@ template <typename XDataType,
...
@@ -76,6 +79,7 @@ template <typename XDataType,
typename
ScaleDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
XYGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
MeanVarGridDesc_M
,
...
@@ -173,6 +177,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
...
@@ -173,6 +177,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
bool
haveSavedMeanInvVar
,
bool
haveSavedMeanInvVar
,
const
MeanVarDataType
*
const
__restrict__
p_savedMean
,
const
MeanVarDataType
*
const
__restrict__
p_savedMean
,
const
MeanVarDataType
*
const
__restrict__
p_savedInvVar
,
const
MeanVarDataType
*
const
__restrict__
p_savedInvVar
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
const
__restrict__
p_dx
,
DxDataType
*
const
__restrict__
p_dx
,
ScaleDataType
*
const
__restrict__
p_dscale
,
ScaleDataType
*
const
__restrict__
p_dscale
,
BiasDataType
*
const
__restrict__
p_dbias
)
BiasDataType
*
const
__restrict__
p_dbias
)
...
@@ -455,6 +460,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
...
@@ -455,6 +460,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
constexpr
auto
offset
=
constexpr
auto
offset
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
dy_elementwise_op
(
dy_thread_buf
(
Number
<
offset
>
{}),
dy_thread_buf
[
Number
<
offset
>
{}]);
AccDataType
norm_x
=
(
x_thread_buf
[
Number
<
offset
>
{}]
-
mean_thread_buf
[
iM
])
*
AccDataType
norm_x
=
(
x_thread_buf
[
Number
<
offset
>
{}]
-
mean_thread_buf
[
iM
])
*
inv_var_thread_buf
[
iM
];
inv_var_thread_buf
[
iM
];
...
@@ -531,6 +539,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
...
@@ -531,6 +539,9 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
constexpr
auto
offset
=
constexpr
auto
offset
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
dy_elementwise_op
(
dy_thread_buf
(
Number
<
offset
>
{}),
dy_thread_buf
[
Number
<
offset
>
{}]);
AccDataType
norm_x
=
(
x_thread_buf
[
Number
<
offset
>
{}]
-
mean_thread_buf
[
iM
])
*
AccDataType
norm_x
=
(
x_thread_buf
[
Number
<
offset
>
{}]
-
mean_thread_buf
[
iM
])
*
inv_var_thread_buf
[
iM
];
inv_var_thread_buf
[
iM
];
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward_nhwc_c.hpp
View file @
59613285
...
@@ -19,8 +19,10 @@ template <typename XDataType,
...
@@ -19,8 +19,10 @@ template <typename XDataType,
typename
AccDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
>
typename
MeanVarDataType
,
struct
ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
:
public
device
::
DeviceBatchNormBwd
<
4
,
3
>
typename
DyElementwiseOp
>
struct
ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
:
public
device
::
DeviceBatchNormBwd
<
4
,
3
,
DyElementwiseOp
>
{
{
struct
Argument
:
public
device
::
BaseArgument
struct
Argument
:
public
device
::
BaseArgument
{
{
...
@@ -39,6 +41,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
...
@@ -39,6 +41,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const
MeanVarDataType
*
p_savedMean
,
const
MeanVarDataType
*
p_savedMean
,
const
MeanVarDataType
*
p_savedInvVar
,
const
MeanVarDataType
*
p_savedInvVar
,
double
epsilon
,
double
epsilon
,
const
DyElementwiseOp
dy_elementwise_op
,
DxDataType
*
p_dx
,
DxDataType
*
p_dx
,
ScaleDataType
*
p_dscale
,
ScaleDataType
*
p_dscale
,
BiasDataType
*
p_dbias
)
BiasDataType
*
p_dbias
)
...
@@ -48,6 +51,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
...
@@ -48,6 +51,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
p_savedMean_
(
p_savedMean
),
p_savedMean_
(
p_savedMean
),
p_savedInvVar_
(
p_savedInvVar
),
p_savedInvVar_
(
p_savedInvVar
),
epsilon_
(
epsilon
),
epsilon_
(
epsilon
),
dy_elementwise_op_
(
dy_elementwise_op
),
p_dx_
(
p_dx
),
p_dx_
(
p_dx
),
p_dscale_
(
p_dscale
),
p_dscale_
(
p_dscale
),
p_dbias_
(
p_dbias
)
p_dbias_
(
p_dbias
)
...
@@ -79,6 +83,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
...
@@ -79,6 +83,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const
MeanVarDataType
*
p_savedInvVar_
;
const
MeanVarDataType
*
p_savedInvVar_
;
double
epsilon_
;
double
epsilon_
;
const
DyElementwiseOp
dy_elementwise_op_
;
DxDataType
*
p_dx_
;
DxDataType
*
p_dx_
;
ScaleDataType
*
p_dscale_
;
ScaleDataType
*
p_dscale_
;
...
@@ -165,6 +170,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
...
@@ -165,6 +170,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType
norm_x
=
(
x
-
mean
)
*
invVar
;
AccDataType
norm_x
=
(
x
-
mean
)
*
invVar
;
AccDataType
dy
=
type_convert
<
AccDataType
>
(
arg
.
p_dy_
[
offset
]);
AccDataType
dy
=
type_convert
<
AccDataType
>
(
arg
.
p_dy_
[
offset
]);
arg
.
dy_elementwise_op_
(
dy
,
dy
);
dbias
+=
dy
;
dbias
+=
dy
;
dscale
+=
norm_x
*
dy
;
dscale
+=
norm_x
*
dy
;
};
};
...
@@ -194,6 +201,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
...
@@ -194,6 +201,8 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
AccDataType
dy
=
type_convert
<
AccDataType
>
(
arg
.
p_dy_
[
offset
]);
AccDataType
dy
=
type_convert
<
AccDataType
>
(
arg
.
p_dy_
[
offset
]);
AccDataType
scale
=
type_convert
<
AccDataType
>
(
arg
.
p_scale_
[
offset_C
]);
AccDataType
scale
=
type_convert
<
AccDataType
>
(
arg
.
p_scale_
[
offset_C
]);
arg
.
dy_elementwise_op_
(
dy
,
dy
);
AccDataType
tmpVal
=
norm_x
*
dscale
;
AccDataType
tmpVal
=
norm_x
*
dscale
;
AccDataType
dx
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
reduceSize
*
invVar
*
AccDataType
dx
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
reduceSize
*
invVar
*
...
@@ -258,6 +267,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
...
@@ -258,6 +267,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
const
void
*
p_savedMean
,
const
void
*
p_savedMean
,
const
void
*
p_savedInvVar
,
const
void
*
p_savedInvVar
,
double
epsilon
,
double
epsilon
,
const
DyElementwiseOp
dy_elementwise_op
,
void
*
p_dx
,
void
*
p_dx
,
void
*
p_dscale
,
void
*
p_dscale
,
void
*
p_dbias
)
override
void
*
p_dbias
)
override
...
@@ -277,6 +287,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
...
@@ -277,6 +287,7 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
static_cast
<
const
MeanVarDataType
*>
(
p_savedMean
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedMean
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedInvVar
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedInvVar
),
epsilon
,
epsilon
,
dy_elementwise_op
,
static_cast
<
DxDataType
*>
(
p_dx
),
static_cast
<
DxDataType
*>
(
p_dx
),
static_cast
<
ScaleDataType
*>
(
p_dscale
),
static_cast
<
ScaleDataType
*>
(
p_dscale
),
static_cast
<
BiasDataType
*>
(
p_dbias
));
static_cast
<
BiasDataType
*>
(
p_dbias
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment