gaoqiong / composable_kernel · Commit de6aad06
Authored Oct 31, 2022 by Qianfeng Zhang
Parameters renaming again in batchnorm backward kernels
Parent: 7d114e80
Showing 3 changed files with 155 additions and 159 deletions
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp (+67 −69)
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp (+48 −48)
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp (+40 −42)
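This commit is a pure rename pass: `scale_diff*` / `bias_diff*` identifiers become `dscale*` / `dbias*`, `scale_bias_diff_grid_desc_*` becomes `dscale_dbias_grid_desc_*`, and the `*_global_val_buf` buffers become `*_global_buf`. No computation changes. For orientation, here is a minimal scalar sketch (not CK code; names are illustrative) of the per-channel math these kernels implement, matching the Step 1 / Step 2 comments in the diffs below:

```cpp
#include <cstddef>
#include <vector>

// Scalar reference for batchnorm backward over one channel of N elements:
//   dscale = sum(dy * x_hat), dbias = sum(dy),
//   dx = 1/N * invVar * scale * (N * dy - dbias - x_hat * dscale),
// where x_hat = (x - mean) * invVar.
void batchnorm_bwd_channel(const std::vector<float>& x,
                           const std::vector<float>& dy,
                           float mean,
                           float invVar, // 1 / sqrt(var + eps)
                           float scale,  // gamma
                           std::vector<float>& dx,
                           float& dscale,
                           float& dbias)
{
    const std::size_t N = x.size();
    dscale = 0.0f;
    dbias  = 0.0f;

    // Reduction pass: accumulate the two channel-wide sums.
    for(std::size_t i = 0; i < N; ++i)
    {
        const float x_hat = (x[i] - mean) * invVar; // normalized input
        dscale += dy[i] * x_hat;
        dbias  += dy[i];
    }

    // Elementwise pass: input gradient, with the common factor hoisted out.
    const float multiplier = 1.0f / static_cast<float>(N) * invVar * scale;
    for(std::size_t i = 0; i < N; ++i)
    {
        const float x_hat = (x[i] - mean) * invVar;
        dx[i] = multiplier * (static_cast<float>(N) * dy[i] - dbias - x_hat * dscale);
    }
}
```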
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp

@@ -26,7 +26,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
     const XYGridDesc_M_K x_grid_desc_m_k,
     const XYGridDesc_M_K dy_grid_desc_m_k,
     const XYGridDesc_M_K dx_grid_desc_m_k,
-    const ScaleBiasDiffGridDesc_M_K scale_bias_diff_grid_desc_m_k,
+    const ScaleBiasDiffGridDesc_M_K dscale_dbias_grid_desc_m_k,
     const MeanVarGridDesc_M mean_var_grid_desc_m,
     const ScaleBiasGridDesc_M scale_grid_desc_m,
     const ScaleBiasGridDesc_M bias_grid_desc_m,
@@ -48,7 +48,7 @@ __global__ void kernel_reduce_second_half_batchnorm_backward_final(
     GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
                                                          dy_grid_desc_m_k,
                                                          dx_grid_desc_m_k,
-                                                         scale_bias_diff_grid_desc_m_k,
+                                                         dscale_dbias_grid_desc_m_k,
                                                          mean_var_grid_desc_m,
                                                          scale_grid_desc_m,
                                                          bias_grid_desc_m,
@@ -143,7 +143,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
     __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
                                const XYGridDesc_M_K& dy_grid_desc_m_k,
                                const XYGridDesc_M_K& dx_grid_desc_m_k,
-                               const ScaleBiasDiffGridDesc_M_K& scale_bias_diff_grid_desc_m_k,
+                               const ScaleBiasDiffGridDesc_M_K& dscale_dbias_grid_desc_m_k,
                                const MeanVarGridDesc_M& mean_var_grid_desc_m,
                                const ScaleBiasGridDesc_M& scale_grid_desc_m,
                                const ScaleBiasGridDesc_M& bias_grid_desc_m,
@@ -168,14 +168,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
             make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
-            reduce_scale_diff_thread_buf;
+            reduce_dscale_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
-            reduce_bias_diff_thread_buf;
+            reduce_dbias_thread_buf;

-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-            scale_diff_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-            bias_diff_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;

         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;
@@ -212,7 +210,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
         // Step 1: do final reduction for scale_diff and bias_diff and output
-        auto threadwise_scale_diff_load_m_k =
+        auto threadwise_dscale_load_m_k =
             ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
                                              AccDataType,
                                              ScaleBiasDiffGridDesc_M_K,
@@ -223,12 +221,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                              1,
                                              1,
                                              true>(
-                scale_bias_diff_grid_desc_m_k,
+                dscale_dbias_grid_desc_m_k,
                 make_multi_index(blkgroup_id * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_k_cluster_id * 1));
-        auto threadwise_bias_diff_load_m_k =
+        auto threadwise_dbias_load_m_k =
             ThreadwiseTensorSliceTransfer_v2<BiasDataType,
                                              AccDataType,
                                              ScaleBiasDiffGridDesc_M_K,
@@ -239,12 +237,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                              1,
                                              1,
                                              true>(
-                scale_bias_diff_grid_desc_m_k,
+                dscale_dbias_grid_desc_m_k,
                 make_multi_index(blkgroup_id * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  thread_k_cluster_id * 1));
-        auto threadwise_scale_diff_store_m =
+        auto threadwise_dscale_store_m =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                ScaleDataType,
                                                decltype(thread_buffer_desc_m),
@@ -262,7 +260,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                 thread_m_cluster_id * MThreadSliceSize),
             PassThroughOp{});
-        auto threadwise_bias_diff_store_m =
+        auto threadwise_dbias_store_m =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                BiasDataType,
                                                decltype(thread_buffer_desc_m),
@@ -280,67 +278,67 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                 thread_m_cluster_id * MThreadSliceSize),
             PassThroughOp{});

-        const auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_dscale, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
+        const auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());

-        const auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_dbias, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
+        const auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());

-        auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_dscale, scale_grid_desc_m.GetElementSpaceSize());

-        auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_dbias, bias_grid_desc_m.GetElementSpaceSize());

-        constexpr auto scale_bias_diff_thread_copy_step_m_k =
+        constexpr auto dscale_dbias_thread_copy_step_m_k =
             make_multi_index(0, KThreadClusterSize * 1);

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            scale_diff_thread_buf(I) = type_convert<AccDataType>(0.0f);
-            bias_diff_thread_buf(I)  = type_convert<AccDataType>(0.0f);
+            dscale_thread_buf(I) = type_convert<AccDataType>(0.0f);
+            dbias_thread_buf(I)  = type_convert<AccDataType>(0.0f);
         });

         for(index_t reducedTiles = 0; reducedTiles < num_scale_bias_diff_k_block_tile_iteration;
             ++reducedTiles)
         {
-            threadwise_scale_diff_load_m_k.Run(scale_bias_diff_grid_desc_m_k,
-                                               reduce_scale_diff_global_val_buf,
-                                               thread_buffer_desc_m_1,
-                                               make_tuple(I0, I0),
-                                               reduce_scale_diff_thread_buf);
-            threadwise_bias_diff_load_m_k.Run(scale_bias_diff_grid_desc_m_k,
-                                              reduce_bias_diff_global_val_buf,
-                                              thread_buffer_desc_m_1,
-                                              make_tuple(I0, I0),
-                                              reduce_bias_diff_thread_buf);
+            threadwise_dscale_load_m_k.Run(dscale_dbias_grid_desc_m_k,
+                                           reduce_dscale_global_buf,
+                                           thread_buffer_desc_m_1,
+                                           make_tuple(I0, I0),
+                                           reduce_dscale_thread_buf);
+            threadwise_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k,
+                                          reduce_dbias_global_buf,
+                                          thread_buffer_desc_m_1,
+                                          make_tuple(I0, I0),
+                                          reduce_dbias_thread_buf);

-            ThreadwiseReduce::Reduce(reduce_scale_diff_thread_buf, scale_diff_thread_buf);
-            ThreadwiseReduce::Reduce(reduce_bias_diff_thread_buf, bias_diff_thread_buf);
+            ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
+            ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);

-            threadwise_scale_diff_load_m_k.MoveSrcSliceWindow(scale_bias_diff_grid_desc_m_k,
-                                                              scale_bias_diff_thread_copy_step_m_k);
-            threadwise_bias_diff_load_m_k.MoveSrcSliceWindow(scale_bias_diff_grid_desc_m_k,
-                                                             scale_bias_diff_thread_copy_step_m_k);
+            threadwise_dscale_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
+                                                          dscale_dbias_thread_copy_step_m_k);
+            threadwise_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
+                                                         dscale_dbias_thread_copy_step_m_k);
         }

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            BlockwiseReduce::Reduce(reduce_work_buf, scale_diff_thread_buf(I));
+            BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
             block_sync_lds();
-            BlockwiseReduce::Reduce(reduce_work_buf, bias_diff_thread_buf(I));
+            BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
         });

-        threadwise_scale_diff_store_m.Run(thread_buffer_desc_m,
-                                          make_tuple(I0),
-                                          scale_diff_thread_buf,
-                                          scale_grid_desc_m,
-                                          scale_diff_global_val_buf);
-        threadwise_bias_diff_store_m.Run(thread_buffer_desc_m,
-                                         make_tuple(I0),
-                                         bias_diff_thread_buf,
-                                         bias_grid_desc_m,
-                                         bias_diff_global_val_buf);
+        threadwise_dscale_store_m.Run(thread_buffer_desc_m,
+                                      make_tuple(I0),
+                                      dscale_thread_buf,
+                                      scale_grid_desc_m,
+                                      dscale_global_buf);
+        threadwise_dbias_store_m.Run(thread_buffer_desc_m,
+                                     make_tuple(I0),
+                                     dbias_thread_buf,
+                                     bias_grid_desc_m,
+                                     dbias_global_buf);

         // Step 2: calculate dx = 1/N * invVar * scale * (N * dy - biasDiff - scaleDiff * (x - mean)
         // * invVar) and output
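Note that the Step 2 comment above still uses the pre-rename biasDiff / scaleDiff names. In the renamed variables, with $N$ = reduce_size and $\hat{x} = (x - \text{mean}) \cdot \text{invVar}$, it describes

$$ dx = \frac{1}{N} \cdot \text{invVar} \cdot \text{scale} \cdot \left( N \cdot dy - \text{dbias} - \hat{x} \cdot \text{dscale} \right), $$

which matches the standard batchnorm input-gradient formula when $\text{dscale} = \sum dy \,\hat{x}$ and $\text{dbias} = \sum dy$ are taken over the reduction dimension.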
@@ -426,38 +424,38 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
             make_multi_index(blkgroup_id * M_BlockTileSize +
                              thread_m_cluster_id * MThreadSliceSize));

-        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_x, x_grid_desc_m_k.GetElementSpaceSize());
-        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
-        auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
-        const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_scale, scale_grid_desc_m.GetElementSpaceSize());
-        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_mean, mean_var_grid_desc_m.GetElementSpaceSize());
-        const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize());

         threadwise_scale_load.Run(scale_grid_desc_m,
-                                  scale_global_val_buf,
+                                  scale_global_buf,
                                   thread_buffer_desc_m,
                                   make_tuple(I0),
                                   scale_thread_buf);
         threadwise_mean_var_load.Run(mean_var_grid_desc_m,
-                                     mean_global_val_buf,
+                                     mean_global_buf,
                                      thread_buffer_desc_m,
                                      make_tuple(I0),
                                      mean_thread_buf);
         threadwise_mean_var_load.Run(mean_var_grid_desc_m,
-                                     inv_var_global_val_buf,
+                                     inv_var_global_buf,
                                      thread_buffer_desc_m,
                                      make_tuple(I0),
                                      inv_var_thread_buf);
@@ -467,13 +465,13 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
         for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
         {
             threadwise_x_load.Run(x_grid_desc_m_k,
-                                  x_global_val_buf,
+                                  x_global_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   x_thread_buf);
             threadwise_dy_load.Run(dy_grid_desc_m_k,
-                                   dy_global_val_buf,
+                                   dy_global_buf,
                                    thread_buffer_desc_m_k,
                                    make_tuple(I0, I0),
                                    dy_thread_buf);
@@ -490,12 +488,12 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                     AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                          inv_var_thread_buf[iM];

-                    AccDataType tmpVal = norm_x * scale_diff_thread_buf[iM];
+                    AccDataType tmpVal = norm_x * dscale_thread_buf[iM];

                     dx_thread_buf(Number<offset>{}) =
                         multiplier *
                         (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
-                         bias_diff_thread_buf[iM] - tmpVal);
+                         dbias_thread_buf[iM] - tmpVal);
                 });
             });
@@ -503,7 +501,7 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
                                    make_tuple(I0, I0),
                                    dx_thread_buf,
                                    dx_grid_desc_m_k,
-                                   dx_global_val_buf);
+                                   dx_global_buf);

             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
             threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp

@@ -28,7 +28,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
     const XYGridDesc_M_K dy_grid_desc_m_k,
     const MeanVarGridDesc_M mean_var_grid_desc_m,
     const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
-    const ScaleBiasDiffGridDesc_M_G scale_bias_grid_desc_m_g,
+    const ScaleBiasDiffGridDesc_M_G dscale_dbias_grid_desc_m_g,
     index_t blkgroup_size,
     index_t num_xy_k_block_tile_iteration,
     index_t num_mean_var_count_k_block_tile_iteration,
@@ -50,7 +50,7 @@ __global__ void kernel_welford_second_half_reduce_first_half(
         dy_grid_desc_m_k,
         mean_var_grid_desc_m,
         mean_var_count_grid_desc_m_k,
-        scale_bias_grid_desc_m_g,
+        dscale_dbias_grid_desc_m_g,
         blkgroup_size,
         num_xy_k_block_tile_iteration,
         num_mean_var_count_k_block_tile_iteration,
@@ -149,7 +149,7 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                                const XYGridDesc_M_K& dy_grid_desc_m_k,
                                const MeanVarGridDesc_M& mean_var_grid_desc_m,
                                const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
-                               const ScaleBiasDiffGridDesc_M_G& scale_bias_diff_grid_desc_m_g,
+                               const ScaleBiasDiffGridDesc_M_G& dscale_dbias_grid_desc_m_g,
                                index_t blkgroup_size,
                                index_t num_xy_k_block_tile_iteration,
                                index_t num_mean_var_count_k_block_tile_iteration,
@@ -201,9 +201,9 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
             tmp1_thread_buf;

         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-            reduce_scale_diff_thread_buf;
+            reduce_dscale_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-            reduce_bias_diff_thread_buf;
+            reduce_dbias_thread_buf;

         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
@@ -231,10 +231,10 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
         if(haveSavedMeanInvVar)
         {
-            const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
-            const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());

             auto threadwise_mean_inv_var_load =
@@ -253,26 +253,26 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                     thread_m_cluster_id * MThreadSliceSize));

             threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
-                                             mean_global_val_buf,
+                                             mean_global_buf,
                                              thread_buffer_desc_m,
                                              make_tuple(I0),
                                              mean_thread_buf);
             threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
-                                             inv_var_global_val_buf,
+                                             inv_var_global_buf,
                                              thread_buffer_desc_m,
                                              make_tuple(I0),
                                              inv_var_thread_buf);
         }
         else
         {
-            const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
-            const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
-            const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

             auto threadwise_mean_var_load_m_k =
@@ -320,19 +320,19 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                 ++reducedTiles)
             {
                 threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
-                                                 welford_mean_global_val_buf,
+                                                 welford_mean_global_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  in_welford_mean_thread_buf);
                 threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
-                                                 welford_var_global_val_buf,
+                                                 welford_var_global_buf,
                                                  thread_buffer_desc_m_1,
                                                  make_tuple(I0, I0),
                                                  in_welford_var_thread_buf);
                 threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
-                                              welford_count_global_val_buf,
+                                              welford_count_global_buf,
                                               thread_buffer_desc_m_1,
                                               make_tuple(I0, I0),
                                               in_welford_count_thread_buf);
@@ -386,23 +386,23 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                     thread_m_cluster_id * MThreadSliceSize),
                 PassThroughOp{});

-            auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());
-            auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());

             threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                               make_tuple(I0),
                                               mean_thread_buf,
                                               mean_var_grid_desc_m,
-                                              mean_global_val_buf);
+                                              mean_global_buf);
             threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                               make_tuple(I0),
                                               inv_var_thread_buf,
                                               mean_var_grid_desc_m,
-                                              inv_var_global_val_buf);
+                                              inv_var_global_buf);
         };
     };
@@ -438,17 +438,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                 workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize));

-        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_x, x_grid_desc_m_k.GetElementSpaceSize());
-        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_dy, dy_grid_desc_m_k.GetElementSpaceSize());

         constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            reduce_scale_diff_thread_buf(I) = type_convert<AccDataType>(0);
-            reduce_bias_diff_thread_buf(I)  = type_convert<AccDataType>(0);
+            reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
+            reduce_dbias_thread_buf(I)  = type_convert<AccDataType>(0);
         });

         // Step 2: do first-half reduction on dy and dy * (x-mean) * inv-variance
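As the Step 2 comment says, this kernel only produces per-block partial sums. With $\hat{x} = (x - \text{mean}) \cdot \text{invVar}$, each workgroup accumulates over its K tiles

$$ \text{dscale}_{\text{partial}} = \sum_{k} dy \,\hat{x}, \qquad \text{dbias}_{\text{partial}} = \sum_{k} dy, $$

and the final cross-block reduction of these partials is performed by kernel_reduce_second_half_batchnorm_backward_final in the first file above.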
@@ -456,13 +456,13 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
         for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
         {
             threadwise_x_load.Run(x_grid_desc_m_k,
-                                  x_global_val_buf,
+                                  x_global_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   x_thread_buf);
             threadwise_dy_load.Run(dy_grid_desc_m_k,
-                                   dy_global_val_buf,
+                                   dy_global_buf,
                                    thread_buffer_desc_m_k,
                                    make_tuple(I0, I0),
                                    dy_thread_buf);
@@ -479,20 +479,20 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                 });
             });

-            ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_scale_diff_thread_buf);
-            ThreadwiseReduce::Reduce(dy_thread_buf, reduce_bias_diff_thread_buf);
+            ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
+            ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf);

             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
             threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
         };

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            BlockwiseReduce::Reduce(reduce_work_buf, reduce_scale_diff_thread_buf(I));
+            BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
             block_sync_lds();
-            BlockwiseReduce::Reduce(reduce_work_buf, reduce_bias_diff_thread_buf(I));
+            BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
         });

-        auto threadwise_scale_diff_store =
+        auto threadwise_dscale_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                ScaleDataType,
                                                decltype(thread_buffer_desc_m_1),
@@ -505,13 +505,13 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                                                InMemoryDataOperationEnum::Set,
                                                1,
                                                true>(
-                scale_bias_diff_grid_desc_m_g,
+                dscale_dbias_grid_desc_m_g,
                 make_multi_index(blkgroup_id * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  block_local_id),
                 PassThroughOp{});
-        auto threadwise_bias_diff_store =
+        auto threadwise_dbias_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                BiasDataType,
                                                decltype(thread_buffer_desc_m_1),
@@ -524,31 +524,31 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                                                InMemoryDataOperationEnum::Set,
                                                1,
                                                true>(
-                scale_bias_diff_grid_desc_m_g,
+                dscale_dbias_grid_desc_m_g,
                 make_multi_index(blkgroup_id * M_BlockTileSize +
                                      thread_m_cluster_id * MThreadSliceSize,
                                  block_local_id),
                 PassThroughOp{});

-        auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_dscale, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
+        auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());
-        auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_reduce_dbias, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());
+        auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());

         if(thread_k_cluster_id == 0)
         {
-            threadwise_scale_diff_store.Run(thread_buffer_desc_m_1,
-                                            make_tuple(I0, I0),
-                                            reduce_scale_diff_thread_buf,
-                                            scale_bias_diff_grid_desc_m_g,
-                                            reduce_scale_diff_global_val_buf);
-            threadwise_bias_diff_store.Run(thread_buffer_desc_m_1,
-                                           make_tuple(I0, I0),
-                                           reduce_bias_diff_thread_buf,
-                                           scale_bias_diff_grid_desc_m_g,
-                                           reduce_bias_diff_global_val_buf);
+            threadwise_dscale_store.Run(thread_buffer_desc_m_1,
+                                        make_tuple(I0, I0),
+                                        reduce_dscale_thread_buf,
+                                        dscale_dbias_grid_desc_m_g,
+                                        reduce_dscale_global_buf);
+            threadwise_dbias_store.Run(thread_buffer_desc_m_1,
+                                       make_tuple(I0, I0),
+                                       reduce_dbias_thread_buf,
+                                       dscale_dbias_grid_desc_m_g,
+                                       reduce_dbias_global_buf);
         };
     };
 };
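A layout note on the partials, as the indexing above suggests: only the threads with thread_k_cluster_id == 0 write out their block's partial dscale/dbias, and each block writes one column of the (M, G) grid selected by block_local_id, so the G dimension enumerates the workgroups that share a channel.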
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp

@@ -204,10 +204,8 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
             inv_var_thread_buf = var_thread_buf;

-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-            scale_diff_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-            bias_diff_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dscale_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> dbias_thread_buf;

         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
@@ -289,7 +287,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
             make_multi_index(block_global_id * M_BlockTileSize +
                              thread_m_cluster_id * MThreadSliceSize));

-        auto threadwise_scale_diff_store =
+        auto threadwise_dscale_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                ScaleDataType,
                                                decltype(thread_buffer_desc_m),
@@ -307,7 +305,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                 thread_m_cluster_id * MThreadSliceSize),
             PassThroughOp{});
-        auto threadwise_bias_diff_store =
+        auto threadwise_dbias_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                BiasDataType,
                                                decltype(thread_buffer_desc_m),
@@ -328,30 +326,30 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
         constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
         constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);

-        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_x, x_grid_desc_m_k.GetElementSpaceSize());
-        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
-        auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_dx, dx_grid_desc_m_k.GetElementSpaceSize());
-        const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_scale, scale_grid_desc_m.GetElementSpaceSize());
-        auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_dscale, scale_grid_desc_m.GetElementSpaceSize());
-        auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_dbias, bias_grid_desc_m.GetElementSpaceSize());

         if(haveSavedMeanInvVar)
         {
-            const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
-            const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                 p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());

             auto threadwise_mean_inv_var_load =
@@ -370,13 +368,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                     thread_m_cluster_id * MThreadSliceSize));

             threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
-                                             mean_global_val_buf,
+                                             mean_global_buf,
                                              thread_buffer_desc_m,
                                              make_tuple(I0),
                                              mean_thread_buf);
             threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
-                                             inv_var_global_val_buf,
+                                             inv_var_global_buf,
                                              thread_buffer_desc_m,
                                              make_tuple(I0),
                                              inv_var_thread_buf);
@@ -395,7 +393,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
         {
             threadwise_x_load.Run(x_grid_desc_m_k,
-                                  x_global_val_buf,
+                                  x_global_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   x_thread_buf);
@@ -425,20 +423,20 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
         };

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            scale_diff_thread_buf(I) = type_convert<AccDataType>(0);
-            bias_diff_thread_buf(I)  = type_convert<AccDataType>(0);
+            dscale_thread_buf(I) = type_convert<AccDataType>(0);
+            dbias_thread_buf(I)  = type_convert<AccDataType>(0);
         });

         for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
         {
             threadwise_x_load.Run(x_grid_desc_m_k,
-                                  x_global_val_buf,
+                                  x_global_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   x_thread_buf);
             threadwise_dy_load.Run(dx_grid_desc_m_k,
-                                   dy_global_val_buf,
+                                   dy_global_buf,
                                    thread_buffer_desc_m_k,
                                    make_tuple(I0, I0),
                                    dy_thread_buf);
@@ -455,36 +453,36 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                 });
             });

-            ThreadwiseReduce::Reduce(tmp1_thread_buf, scale_diff_thread_buf);
-            ThreadwiseReduce::Reduce(dy_thread_buf, bias_diff_thread_buf);
+            ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf);
+            ThreadwiseReduce::Reduce(dy_thread_buf, dbias_thread_buf);

             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
             threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
         };

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            BlockwiseReduce::Reduce(reduce_work_buf, scale_diff_thread_buf(I));
+            BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
             block_sync_lds();
-            BlockwiseReduce::Reduce(reduce_work_buf, bias_diff_thread_buf(I));
+            BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
         });

         if(thread_k_cluster_id == 0)
         {
-            threadwise_scale_diff_store.Run(thread_buffer_desc_m,
-                                            make_tuple(I0),
-                                            scale_diff_thread_buf,
-                                            scale_grid_desc_m,
-                                            scale_diff_global_val_buf);
-            threadwise_bias_diff_store.Run(thread_buffer_desc_m,
-                                           make_tuple(I0),
-                                           bias_diff_thread_buf,
-                                           bias_grid_desc_m,
-                                           bias_diff_global_val_buf);
+            threadwise_dscale_store.Run(thread_buffer_desc_m,
+                                        make_tuple(I0),
+                                        dscale_thread_buf,
+                                        scale_grid_desc_m,
+                                        dscale_global_buf);
+            threadwise_dbias_store.Run(thread_buffer_desc_m,
+                                       make_tuple(I0),
+                                       dbias_thread_buf,
+                                       bias_grid_desc_m,
+                                       dbias_global_buf);
         };

         threadwise_scale_load.Run(scale_grid_desc_m,
-                                  scale_global_val_buf,
+                                  scale_global_buf,
                                   thread_buffer_desc_m,
                                   make_tuple(I0),
                                   scale_thread_buf);
@@ -498,13 +496,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
         for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
         {
             threadwise_x_load.Run(x_grid_desc_m_k,
-                                  x_global_val_buf,
+                                  x_global_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   x_thread_buf);
             threadwise_dy_load.Run(dy_grid_desc_m_k,
-                                   dy_global_val_buf,
+                                   dy_global_buf,
                                    thread_buffer_desc_m_k,
                                    make_tuple(I0, I0),
                                    dy_thread_buf);
@@ -517,13 +515,13 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                     AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                          inv_var_thread_buf[iM];

-                    AccDataType tmpVal = norm_x * scale_diff_thread_buf[iM];
+                    AccDataType tmpVal = norm_x * dscale_thread_buf[iM];

                     dx_thread_buf(Number<offset>{}) =
                         type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size) *
                         inv_var_thread_buf[iM] * scale_thread_buf[iM] *
                         (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
-                         bias_diff_thread_buf[iM] - tmpVal);
+                         dbias_thread_buf[iM] - tmpVal);
                 });
             });
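One small difference from the multiblock final kernel in the first file: there, judging by its Step 2 comment, the common factor 1/N * invVar * scale appears to be pre-folded into a single `multiplier`, while this blockwise-Welford kernel computes it inline in the dx expression. The renaming changes neither computation.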
@@ -531,7 +529,7 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                                    make_tuple(I0, I0),
                                    dx_thread_buf,
                                    dx_grid_desc_m_k,
-                                   dx_global_val_buf);
+                                   dx_global_buf);

             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
             threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);