gaoqiong / composable_kernel

Commit 3f392b53
Authored Oct 31, 2022 by Qianfeng Zhang
Parent: d0b49a14

Parameters renaming in batchnorm backward kernels and device op
Showing 5 changed files with 58 additions and 58 deletions (+58 −58)
include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp (+2 −2)
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp (+24 −24)
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp (+16 −16)
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp (+8 −8)
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp (+8 −8)
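The renamed outputs follow the usual batchnorm backward convention: p_dscale and p_dbias receive the gradients of the loss with respect to the scale (gamma) and bias (beta) parameters, which the kernels below reduce over the batch/spatial dimensions. For reference only (this is not part of the commit or of the CK API), a minimal host-side sketch of what those two outputs compute; the row-major [N][C] layout, the float element type, and the function name are illustrative assumptions.

#include <cstddef>
#include <vector>

// Reference-only sketch: dscale/dbias reduced over the batchnorm reduce
// dimensions. N = reduce length (batch * spatial), C = invariant length
// (channels); a row-major [N][C] layout is assumed purely for illustration.
void batchnorm_bwd_dscale_dbias_reference(const std::vector<float>& x,
                                          const std::vector<float>& dy,
                                          const std::vector<float>& mean,
                                          const std::vector<float>& inv_var,
                                          std::size_t N,
                                          std::size_t C,
                                          std::vector<float>& dscale, // formerly "scaleDiff"
                                          std::vector<float>& dbias)  // formerly "biasDiff"
{
    dscale.assign(C, 0.0f);
    dbias.assign(C, 0.0f);

    for(std::size_t n = 0; n < N; ++n)
    {
        for(std::size_t c = 0; c < C; ++c)
        {
            const float norm_x = (x[n * C + c] - mean[c]) * inv_var[c];
            dscale[c] += dy[n * C + c] * norm_x; // d(loss)/d(scale)
            dbias[c] += dy[n * C + c];           // d(loss)/d(bias)
        }
    }
}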
include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp
@@ -35,8 +35,8 @@ struct DeviceBatchNormBwd : public BaseOperator
                         const void* p_savedInvVar,
                         double epsilon,
                         void* p_dx,
-                        void* p_scaleDiff,
-                        void* p_biasDiff) = 0;
+                        void* p_dscale,
+                        void* p_dbias) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
@@ -201,8 +201,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
                  const MeanVarDataType* p_savedInvVar,
                  double epsilon,
                  DxDataType* p_dx,
-                 ScaleDataType* p_scaleDiff,
-                 BiasDataType* p_biasDiff)
+                 ScaleDataType* p_dscale,
+                 BiasDataType* p_dbias)
             : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
               bnScaleStrides_(bnScaleStrides),
               bnBiasStrides_(bnBiasStrides),
@@ -213,8 +213,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
               p_savedMean_(p_savedMean),
               p_savedInvVar_(p_savedInvVar),
               p_dx_(p_dx),
-              p_scaleDiff_(p_scaleDiff),
-              p_biasDiff_(p_biasDiff)
+              p_dscale_(p_dscale),
+              p_dbias_(p_dbias)
         {
             xyLengths_ =
                 shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
@@ -294,8 +294,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
         const MeanVarDataType* p_savedMean_;
         const MeanVarDataType* p_savedInvVar_;
         DxDataType* p_dx_;
-        ScaleDataType* p_scaleDiff_;
-        BiasDataType* p_biasDiff_;
+        ScaleDataType* p_dscale_;
+        BiasDataType* p_dbias_;

         long_index_t invariant_length;
         long_index_t reduce_length;
@@ -318,8 +318,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
         void* workspace_savedMean;
         void* workspace_savedInvVar;

-        void* workspace_reduce_scale_diff;
-        void* workspace_reduce_bias_diff;
+        void* workspace_reduce_dscale;
+        void* workspace_reduce_dbias;
     };

     size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
@@ -372,14 +372,14 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
         index_t space_sz;

         // setup buffer for the partial reduced result for scale_diff
-        pArg_->workspace_reduce_scale_diff = pArg_->p_workspace_;
+        pArg_->workspace_reduce_dscale = pArg_->p_workspace_;

         space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(ScaleDataType);
         space_sz = math::integer_least_multiple(space_sz, 64);

         // setup buffer for the partial reduced result for bias_diff
-        pArg_->workspace_reduce_bias_diff =
-            reinterpret_cast<char*>(pArg_->workspace_reduce_scale_diff) + space_sz;
+        pArg_->workspace_reduce_dbias =
+            reinterpret_cast<char*>(pArg_->workspace_reduce_dscale) + space_sz;

         if(UseMultiblockInK && pArg_->blkGroupSize > 1)
         {
@@ -388,7 +388,7 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
             // setup buffer for welford intermediate mean
             pArg_->workspace_mean =
-                reinterpret_cast<char*>(pArg_->workspace_reduce_bias_diff) + space_sz;
+                reinterpret_cast<char*>(pArg_->workspace_reduce_dbias) + space_sz;

             space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
             space_sz = math::integer_least_multiple(space_sz, 64);
@@ -604,8 +604,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
                         : static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
                     arg.p_x_,
                     arg.p_dy_,
-                    static_cast<ScaleDataType*>(arg.workspace_reduce_scale_diff),
-                    static_cast<BiasDataType*>(arg.workspace_reduce_bias_diff));
+                    static_cast<ScaleDataType*>(arg.workspace_reduce_dscale),
+                    static_cast<BiasDataType*>(arg.workspace_reduce_dbias));

                 avg_time += launch_and_time_kernel(
                     stream_config,
@@ -624,8 +624,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
                     arg.reduce_length,
                     arg.numBlockTileIteration,
                     numScaleBiasDiffBlockTileIteration,
-                    static_cast<const ScaleDataType*>(arg.workspace_reduce_scale_diff),
-                    static_cast<const BiasDataType*>(arg.workspace_reduce_bias_diff),
+                    static_cast<const ScaleDataType*>(arg.workspace_reduce_dscale),
+                    static_cast<const BiasDataType*>(arg.workspace_reduce_dbias),
                     arg.haveSavedMeanInvVar_
                         ? arg.p_savedMean_
                         : static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
@@ -636,8 +636,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
                     arg.p_dy_,
                     arg.p_scale_,
                     arg.p_dx_,
-                    arg.p_scaleDiff_,
-                    arg.p_biasDiff_);
+                    arg.p_dscale_,
+                    arg.p_dbias_);
             }
             else
             {
@@ -708,8 +708,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
                     arg.p_savedMean_,
                     arg.p_savedInvVar_,
                     arg.p_dx_,
-                    arg.p_scaleDiff_,
-                    arg.p_biasDiff_);
+                    arg.p_dscale_,
+                    arg.p_dbias_);
             };

             return (avg_time);
@@ -801,8 +801,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
                         const void* p_savedInvVar,
                         double epsilon,
                         void* p_dx,
-                        void* p_scaleDiff,
-                        void* p_biasDiff) override
+                        void* p_dscale,
+                        void* p_dbias) override
     {
         return std::make_unique<Argument>(xyLengths,
                                           xStrides,
@@ -820,8 +820,8 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<Rank, NumBatchNormRedu
                                           static_cast<const MeanVarDataType*>(p_savedInvVar),
                                           epsilon,
                                           static_cast<DxDataType*>(p_dx),
-                                          static_cast<ScaleDataType*>(p_scaleDiff),
-                                          static_cast<BiasDataType*>(p_biasDiff));
+                                          static_cast<ScaleDataType*>(p_dscale),
+                                          static_cast<BiasDataType*>(p_dbias));
     };

     std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
@@ -34,16 +34,16 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
     long_index_t reduce_size,
     index_t num_xy_k_block_tile_iteration,
     index_t num_scale_bias_diff_k_block_tile_iteration,
-    const ScaleDataType* const __restrict__ p_reduce_scale_diff,
-    const BiasDataType* const __restrict__ p_reduce_bias_diff,
+    const ScaleDataType* const __restrict__ p_reduce_dscale,
+    const BiasDataType* const __restrict__ p_reduce_dbias,
     const MeanVarDataType* const __restrict__ p_mean,
     const MeanVarDataType* const __restrict__ p_inv_var,
     const XDataType* const __restrict__ p_x,
     const DyDataType* const __restrict__ p_dy,
     const ScaleDataType* const __restrict__ p_scale,
     DxDataType* const __restrict__ p_dx,
-    ScaleDataType* const __restrict__ p_scale_diff,
-    BiasDataType* const __restrict__ p_bias_diff)
+    ScaleDataType* const __restrict__ p_dscale,
+    BiasDataType* const __restrict__ p_dbias)
 {
     GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
                                                          dy_grid_desc_m_k,
@@ -56,16 +56,16 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
                                                          reduce_size,
                                                          num_xy_k_block_tile_iteration,
                                                          num_scale_bias_diff_k_block_tile_iteration,
-                                                         p_reduce_scale_diff,
-                                                         p_reduce_bias_diff,
+                                                         p_reduce_dscale,
+                                                         p_reduce_dbias,
                                                          p_mean,
                                                          p_inv_var,
                                                          p_x,
                                                          p_dy,
                                                          p_scale,
                                                          p_dx,
-                                                         p_scale_diff,
-                                                         p_bias_diff);
+                                                         p_dscale,
+                                                         p_dbias);
 };

 template <typename XDataType,
@@ -151,16 +151,16 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                long_index_t reduce_size,
                                index_t num_xy_k_block_tile_iteration,
                                index_t num_scale_bias_diff_k_block_tile_iteration,
-                               const ScaleDataType* const __restrict__ p_reduce_scale_diff,
-                               const BiasDataType* const __restrict__ p_reduce_bias_diff,
+                               const ScaleDataType* const __restrict__ p_reduce_dscale,
+                               const BiasDataType* const __restrict__ p_reduce_dbias,
                                const MeanVarDataType* const __restrict__ p_mean,
                                const MeanVarDataType* const __restrict__ p_inv_var,
                                const XDataType* const __restrict__ p_x,
                                const DyDataType* const __restrict__ p_dy,
                                const ScaleDataType* const __restrict__ p_scale,
                                DxDataType* const __restrict__ p_dx,
-                               ScaleDataType* const __restrict__ p_scale_diff,
-                               BiasDataType* const __restrict__ p_bias_diff)
+                               ScaleDataType* const __restrict__ p_dscale,
+                               BiasDataType* const __restrict__ p_dbias)
     {
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];

@@ -281,16 +281,16 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                                              PassThroughOp{});

         const auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_scale_diff, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
+            p_reduce_dscale, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());

         const auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_bias_diff, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
+            p_reduce_dbias, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());

         auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_scale_diff, scale_grid_desc_m.GetElementSpaceSize());
+            p_dscale, scale_grid_desc_m.GetElementSpaceSize());

         auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_bias_diff, bias_grid_desc_m.GetElementSpaceSize());
+            p_dbias, bias_grid_desc_m.GetElementSpaceSize());

         constexpr auto scale_bias_diff_thread_copy_step_m_k =
             make_multi_index(0, KThreadClusterSize * 1);
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp
@@ -43,8 +43,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
     MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
     const XDataType* const __restrict__ p_x,
     const DyDataType* const __restrict__ p_dy,
-    ScaleDataType* const __restrict__ p_reduce_scale_diff,
-    BiasDataType* const __restrict__ p_reduce_bias_diff)
+    ScaleDataType* const __restrict__ p_reduce_dscale,
+    BiasDataType* const __restrict__ p_reduce_dbias)
 {
     GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
                                                    dy_grid_desc_m_k,
@@ -65,8 +65,8 @@ __global__ void kernel_welford_second_half_reduce_first_half(
                                                    p_out_welford_inv_variance,
                                                    p_x,
                                                    p_dy,
-                                                   p_reduce_scale_diff,
-                                                   p_reduce_bias_diff);
+                                                   p_reduce_dscale,
+                                                   p_reduce_dbias);
 };

 template <typename XDataType,
@@ -164,8 +164,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                                MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
                                const XDataType* const __restrict__ p_x,
                                const DyDataType* const __restrict__ p_dy,
-                               ScaleDataType* const __restrict__ p_reduce_scale_diff,
-                               BiasDataType* const __restrict__ p_reduce_bias_diff)
+                               ScaleDataType* const __restrict__ p_reduce_dscale,
+                               BiasDataType* const __restrict__ p_reduce_dbias)
     {
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];

@@ -531,10 +531,10 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                                                              PassThroughOp{});

         auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_scale_diff, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
+            p_reduce_dscale, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());

         auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_bias_diff, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
+            p_reduce_dbias, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());

         if(thread_k_cluster_id == 0)
         {
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp
@@ -45,8 +45,8 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
     const MeanVarDataType* const __restrict__ p_savedMean,
     const MeanVarDataType* const __restrict__ p_savedInvVar,
     DxDataType* const __restrict__ p_dx,
-    ScaleDataType* const __restrict__ p_scale_diff,
-    BiasDataType* const __restrict__ p_bias_diff)
+    ScaleDataType* const __restrict__ p_dscale,
+    BiasDataType* const __restrict__ p_dbias)
 {
     GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
                                                          dy_grid_desc_m_k,
@@ -65,8 +65,8 @@ __global__ void kernel_batchnorm_backward_with_blockwise_welford(
                                                          p_savedMean,
                                                          p_savedInvVar,
                                                          p_dx,
-                                                         p_scale_diff,
-                                                         p_bias_diff);
+                                                         p_dscale,
+                                                         p_dbias);
 };

 template <typename XDataType,
@@ -166,8 +166,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                                const MeanVarDataType* const __restrict__ p_savedMean,
                                const MeanVarDataType* const __restrict__ p_savedInvVar,
                                DxDataType* const __restrict__ p_dx,
-                               ScaleDataType* const __restrict__ p_scale_diff,
-                               BiasDataType* const __restrict__ p_bias_diff)
+                               ScaleDataType* const __restrict__ p_dscale,
+                               BiasDataType* const __restrict__ p_dbias)
     {
         using ck::math::sqrt;

@@ -333,10 +333,10 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
             p_scale, scale_grid_desc_m.GetElementSpaceSize());

         auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_scale_diff, scale_grid_desc_m.GetElementSpaceSize());
+            p_dscale, scale_grid_desc_m.GetElementSpaceSize());

         auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_bias_diff, bias_grid_desc_m.GetElementSpaceSize());
+            p_dbias, bias_grid_desc_m.GetElementSpaceSize());

         if(haveSavedMeanInvVar)
         {