gaoqiong / composable_kernel / Commits / bbd19d89
Commit bbd19d89, authored Sep 29, 2022 by Qianfeng Zhang
Implemented batchnorm-backward Blockwise and Multiblock kernels
parent b8825547
Showing 5 changed files with 1915 additions and 0 deletions (+1915, -0)
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp  +535  -0
include/ck/tensor_operation/gpu/grid/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp  +515  -0
include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_first_half.hpp  +258  -0
include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp  +556  -0
include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp  +51  -0
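For reference, the mathematics these kernels implement is the standard batchnorm backward pass. The formulas below are a sketch reconstructed from the kernel code in this commit (N stands for the reduce_size argument, gamma for the scale, and x-hat for the normalized input); they are an orientation aid only, not part of the original commit.

    \hat{x}_i = (x_i - \mu)\,\sigma_{\mathrm{inv}}, \qquad
    \sigma_{\mathrm{inv}} = \frac{1}{\sqrt{\sigma^2 + \epsilon}}

    d\beta = \sum_{i=1}^{N} dy_i, \qquad
    d\gamma = \sum_{i=1}^{N} dy_i\,\hat{x}_i

    dx_i = \frac{1}{N}\,\sigma_{\mathrm{inv}}\,\gamma\,\bigl(N\,dy_i - d\beta - \hat{x}_i\,d\gamma\bigr)

In the code, scale_diff corresponds to d-gamma, bias_diff to d-beta, and inv_var to sigma-inv; the Blockwise variant computes everything in one kernel, while the Multiblock variants split the reduction dimension across work-groups and finish in follow-up kernels.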
include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp    new file (0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseBatchrNormBackwardWithBlockwiseWelford_,
          typename XDataType, typename DyDataType, typename DxDataType, typename AccDataType,
          typename ScaleDataType, typename BiasDataType, typename MeanVarDataType,
          typename XYGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M,
          typename GetReduceCountPerThreadFunctor>
__global__ void kernel_batchnorm_backward_with_blockwise_welford(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K dy_grid_desc_m_k,
    const XYGridDesc_M_K dx_grid_desc_m_k,
    const ScaleBiasGridDesc_M scale_grid_desc_m,
    const ScaleBiasGridDesc_M bias_grid_desc_m,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
    long_index_t reduce_size,
    index_t num_k_block_tile_iteration,
    AccDataType epsilon,
    const XDataType* const __restrict__ p_x,
    const DyDataType* const __restrict__ p_dy,
    const ScaleDataType* const __restrict__ p_scale,
    bool haveSavedMeanInvVar,
    const MeanVarDataType* const __restrict__ p_savedMean,
    const MeanVarDataType* const __restrict__ p_savedInvVar,
    DxDataType* const __restrict__ p_dx,
    ScaleDataType* const __restrict__ p_scale_diff,
    BiasDataType* const __restrict__ p_bias_diff)
{
    GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(
        x_grid_desc_m_k, dy_grid_desc_m_k, dx_grid_desc_m_k, scale_grid_desc_m, bias_grid_desc_m,
        mean_var_grid_desc_m, get_reduce_count_per_thread, reduce_size, num_k_block_tile_iteration,
        epsilon, p_x, p_dy, p_scale, haveSavedMeanInvVar, p_savedMean, p_savedInvVar, p_dx,
        p_scale_diff, p_bias_diff);
};

template <typename XDataType, typename DyDataType, typename DxDataType, typename AccDataType,
          typename ScaleDataType, typename BiasDataType, typename MeanVarDataType,
          typename XYGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M,
          typename GetReduceCountPerThreadFunctor,
          index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize,
          index_t MThreadSliceSize, index_t KThreadSliceSize,
          index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize,
          index_t DxDstVectorSize, index_t ScaleSrcDstVectorSize, index_t BiasDstVectorSize,
          index_t MeanVarSrcVectorSize>
struct GridwiseBatchNormBackwardWithBlockwiseWelford
{
    static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
                   MThreadSliceSize % DySrcVectorSize == 0 &&
                   MThreadSliceSize % DxDstVectorSize == 0) ||
                      (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                       KThreadSliceSize % DySrcVectorSize == 0 &&
                       KThreadSliceSize % DxDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    using BlockwiseReduce =
        PartitionedBlockwiseReduction<AccDataType, BlockSize, ThreadClusterLengths_M_K,
                                      ThreadClusterArrangeOrder, ck::reduce::Add, false>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType, ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M, ck::reduce::Add, false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k,
                               const XYGridDesc_M_K dy_grid_desc_m_k,
                               const XYGridDesc_M_K dx_grid_desc_m_k,
                               const ScaleBiasGridDesc_M scale_grid_desc_m,
                               const ScaleBiasGridDesc_M bias_grid_desc_m,
                               const MeanVarGridDesc_M mean_var_grid_desc_m,
                               const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
                               long_index_t reduce_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType epsilon,
                               const XDataType* const __restrict__ p_x,
                               const DyDataType* const __restrict__ p_dy,
                               const ScaleDataType* const __restrict__ p_scale,
                               bool haveSavedMeanInvVar,
                               const MeanVarDataType* const __restrict__ p_savedMean,
                               const MeanVarDataType* const __restrict__ p_savedInvVar,
                               DxDataType* const __restrict__ p_dx,
                               ScaleDataType* const __restrict__ p_scale_diff,
                               BiasDataType* const __restrict__ p_bias_diff)
    {
        using ck::math::sqrt;

        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            dy_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            dx_thread_buf;

        // buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            tmp1_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
            inv_var_thread_buf = var_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            scale_diff_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            bias_diff_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;

        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<
            XDataType, AccDataType, XYGridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder, XDyDxVectorDim, XSrcVectorSize, 1,
            true>(x_grid_desc_m_k,
                  make_multi_index(block_global_id * M_BlockTileSize +
                                       thread_m_cluster_id * MThreadSliceSize,
                                   thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<
            DyDataType, AccDataType, XYGridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder, XDyDxVectorDim, XSrcVectorSize, 1,
            true>(x_grid_desc_m_k,
                  make_multi_index(block_global_id * M_BlockTileSize +
                                       thread_m_cluster_id * MThreadSliceSize,
                                   thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dx_store = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType, DxDataType, decltype(thread_buffer_desc_m_k), XYGridDesc_M_K,
            PassThroughOp, ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder, XDyDxVectorDim,
            DxDstVectorSize, InMemoryDataOperationEnum::Set, 1, true>(
            dy_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize),
            PassThroughOp{});

        auto threadwise_scale_load = ThreadwiseTensorSliceTransfer_v2<
            ScaleDataType, AccDataType, ScaleBiasGridDesc_M, decltype(thread_buffer_desc_m),
            ThreadBufferLengths_M, Sequence<0>, 0, ScaleSrcDstVectorSize, 1, true>(
            scale_grid_desc_m,
            make_multi_index(block_global_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize));

        auto threadwise_scale_diff_store = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType, ScaleDataType, decltype(thread_buffer_desc_m), ScaleBiasGridDesc_M,
            PassThroughOp, ThreadBufferLengths_M, Sequence<0>, 0, ScaleSrcDstVectorSize,
            InMemoryDataOperationEnum::Set, 1, true>(
            scale_grid_desc_m,
            make_multi_index(block_global_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize),
            PassThroughOp{});

        auto threadwise_bias_diff_store = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType, BiasDataType, decltype(thread_buffer_desc_m), ScaleBiasGridDesc_M,
            PassThroughOp, ThreadBufferLengths_M, Sequence<0>, 0, BiasDstVectorSize,
            InMemoryDataOperationEnum::Set, 1, true>(
            bias_grid_desc_m,
            make_multi_index(block_global_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize),
            PassThroughOp{});

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
        constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());
        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
        auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dx, dx_grid_desc_m_k.GetElementSpaceSize());

        const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale, scale_grid_desc_m.GetElementSpaceSize());
        auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale_diff, scale_grid_desc_m.GetElementSpaceSize());
        auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias_diff, bias_grid_desc_m.GetElementSpaceSize());

        if(haveSavedMeanInvVar)
        {
            const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());
            const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());

            auto threadwise_mean_inv_var_load = ThreadwiseTensorSliceTransfer_v2<
                MeanVarDataType, AccDataType, MeanVarGridDesc_M, decltype(thread_buffer_desc_m),
                ThreadBufferLengths_M, Sequence<0>, 0, MeanVarSrcVectorSize, 1, true>(
                mean_var_grid_desc_m,
                make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize));

            threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, mean_global_val_buf,
                                             thread_buffer_desc_m, make_tuple(I0),
                                             mean_thread_buf);
            threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, inv_var_global_val_buf,
                                             thread_buffer_desc_m, make_tuple(I0),
                                             inv_var_thread_buf);
        }
        else
        {
            auto threadwise_welford       = ThreadwiseWelford();
            threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
                var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
            });

            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                      make_tuple(I0, I0), x_thread_buf);
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
            }

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();

                int count = threadwise_welford.cur_count_;
                BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                inv_var_thread_buf(I) =
                    type_convert<AccDataType>(1.0) / sqrt(var_thread_buf[I] + epsilon);
            });

            threadwise_x_load.SetSrcSliceOrigin(
                x_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));
        };

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            scale_diff_thread_buf(I) = type_convert<AccDataType>(0);
            bias_diff_thread_buf(I)  = type_convert<AccDataType>(0);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                  make_tuple(I0, I0), x_thread_buf);
            threadwise_dy_load.Run(dx_grid_desc_m_k, dy_global_val_buf, thread_buffer_desc_m_k,
                                   make_tuple(I0, I0), dy_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                         inv_var_thread_buf[iM];

                    tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
                });
            });

            ThreadwiseReduce::Reduce(tmp1_thread_buf, scale_diff_thread_buf);
            ThreadwiseReduce::Reduce(dy_thread_buf, bias_diff_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
        };

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            BlockwiseReduce::Reduce(reduce_work_buf, scale_diff_thread_buf(I));
            block_sync_lds();
            BlockwiseReduce::Reduce(reduce_work_buf, bias_diff_thread_buf(I));
        });

        if(thread_k_cluster_id == 0)
        {
            threadwise_scale_diff_store.Run(thread_buffer_desc_m, make_tuple(I0),
                                            scale_diff_thread_buf, scale_grid_desc_m,
                                            scale_diff_global_val_buf);
            threadwise_bias_diff_store.Run(thread_buffer_desc_m, make_tuple(I0),
                                           bias_diff_thread_buf, bias_grid_desc_m,
                                           bias_diff_global_val_buf);
        };

        threadwise_scale_load.Run(scale_grid_desc_m, scale_global_val_buf, thread_buffer_desc_m,
                                  make_tuple(I0), scale_thread_buf);

        auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;

        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
        threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
        threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                  make_tuple(I0, I0), x_thread_buf);
            threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf, thread_buffer_desc_m_k,
                                   make_tuple(I0, I0), dy_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                         inv_var_thread_buf[iM];

                    AccDataType tmpVal = norm_x * scale_diff_thread_buf[iM];

                    dx_thread_buf(Number<offset>{}) =
                        type_convert<AccDataType>(1.0) / type_convert<AccDataType>(reduce_size) *
                        inv_var_thread_buf[iM] * scale_thread_buf[iM] *
                        (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
                         bias_diff_thread_buf[iM] - tmpVal);
                });
            });

            threadwise_dx_store.Run(thread_buffer_desc_m_k, make_tuple(I0, I0), dx_thread_buf,
                                    dx_grid_desc_m_k, dx_global_val_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
        }
    }
};

} // namespace ck
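The blockwise kernel above delegates the mean/variance accumulation to ThreadwiseWelford and BlockwiseWelford, which are defined in headers that are not part of this diff. As a rough orientation, the per-thread step is the classic Welford online update; the sketch below is a standalone illustration of that update only, with hypothetical names, and is not the in-tree implementation.

    #include <cstdint>

    // Running statistics for one partition of the reduction dimension.
    struct WelfordState
    {
        float mean  = 0.0f; // running mean of the values seen so far
        float m2    = 0.0f; // running sum of squared deviations from the current mean
        int32_t cnt = 0;    // number of values folded in so far
    };

    // Fold one value into the running statistics without storing the samples.
    inline void welford_update(WelfordState& s, float x)
    {
        s.cnt += 1;
        const float delta = x - s.mean;
        s.mean += delta / s.cnt;
        s.m2 += delta * (x - s.mean);
    }

    // Population variance, which is the convention batchnorm uses.
    inline float welford_variance(const WelfordState& s)
    {
        return s.cnt > 0 ? s.m2 / s.cnt : 0.0f;
    }

In the kernel, BlockwiseWelford then combines the per-thread results across the K thread cluster, so each M row ends up with a single mean/variance pair before the gradient loops run.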
include/ck/tensor_operation/gpu/grid/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp    new file (0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
          typename XDataType, typename DyDataType, typename DxDataType,
          typename ScaleDataType, typename BiasDataType, typename MeanVarDataType,
          typename XYGridDesc_M_K, typename ScaleBiasDiffGridDesc_M_K,
          typename MeanVarGridDesc_M, typename ScaleBiasGridDesc_M>
__global__ void kernel_reduce_second_half_batchnorm_backward_final(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K dy_grid_desc_m_k,
    const XYGridDesc_M_K dx_grid_desc_m_k,
    const ScaleBiasDiffGridDesc_M_K scale_bias_diff_grid_desc_m_k,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const ScaleBiasGridDesc_M scale_grid_desc_m,
    const ScaleBiasGridDesc_M bias_grid_desc_m,
    index_t blkgroup_size,
    long_index_t reduce_size,
    index_t num_xy_k_block_tile_iteration,
    index_t num_scale_bias_diff_k_block_tile_iteration,
    const ScaleDataType* const __restrict__ p_reduce_scale_diff,
    const BiasDataType* const __restrict__ p_reduce_bias_diff,
    const MeanVarDataType* const __restrict__ p_mean,
    const MeanVarDataType* const __restrict__ p_inv_var,
    const XDataType* const __restrict__ p_x,
    const DyDataType* const __restrict__ p_dy,
    const ScaleDataType* const __restrict__ p_scale,
    DxDataType* const __restrict__ p_dx,
    ScaleDataType* const __restrict__ p_scale_diff,
    BiasDataType* const __restrict__ p_bias_diff)
{
    GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(
        x_grid_desc_m_k, dy_grid_desc_m_k, dx_grid_desc_m_k, scale_bias_diff_grid_desc_m_k,
        mean_var_grid_desc_m, scale_grid_desc_m, bias_grid_desc_m, blkgroup_size, reduce_size,
        num_xy_k_block_tile_iteration, num_scale_bias_diff_k_block_tile_iteration,
        p_reduce_scale_diff, p_reduce_bias_diff, p_mean, p_inv_var, p_x, p_dy, p_scale, p_dx,
        p_scale_diff, p_bias_diff);
};

template <typename XDataType, typename DyDataType, typename DxDataType, typename AccDataType,
          typename ScaleDataType, typename BiasDataType, typename MeanVarDataType,
          typename XYGridDesc_M_K, typename ScaleBiasDiffGridDesc_M_K,
          typename MeanVarGridDesc_M, typename ScaleBiasGridDesc_M,
          index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize,
          index_t MThreadSliceSize, index_t KThreadSliceSize,
          index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize,
          index_t DxDstVectorSize, index_t ScaleSrcDstVectorSize, index_t BiasDstVectorSize,
          index_t MeanVarSrcVectorSize>
struct GridwiseReduceSecondHalfBatchNormBackwardFinal
{
    static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
                   MThreadSliceSize % DySrcVectorSize == 0 &&
                   MThreadSliceSize % DxDstVectorSize == 0) ||
                      (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                       KThreadSliceSize % DySrcVectorSize == 0 &&
                       KThreadSliceSize % DxDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_1 = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseReduce =
        PartitionedBlockwiseReduction<AccDataType, BlockSize, ThreadClusterLengths_M_K,
                                      ThreadClusterArrangeOrder, ck::reduce::Add, false>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType, ThreadReduceSrcDesc_M_1,
                                                 ThreadReduceDstDesc_M, ck::reduce::Add, false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
                               const XYGridDesc_M_K& dy_grid_desc_m_k,
                               const XYGridDesc_M_K& dx_grid_desc_m_k,
                               const ScaleBiasDiffGridDesc_M_K& scale_bias_diff_grid_desc_m_k,
                               const MeanVarGridDesc_M& mean_var_grid_desc_m,
                               const ScaleBiasGridDesc_M& scale_grid_desc_m,
                               const ScaleBiasGridDesc_M& bias_grid_desc_m,
                               index_t blkgroup_size,
                               long_index_t reduce_size,
                               index_t num_xy_k_block_tile_iteration,
                               index_t num_scale_bias_diff_k_block_tile_iteration,
                               const ScaleDataType* const __restrict__ p_reduce_scale_diff,
                               const BiasDataType* const __restrict__ p_reduce_bias_diff,
                               const MeanVarDataType* const __restrict__ p_mean,
                               const MeanVarDataType* const __restrict__ p_inv_var,
                               const XDataType* const __restrict__ p_x,
                               const DyDataType* const __restrict__ p_dy,
                               const ScaleDataType* const __restrict__ p_scale,
                               DxDataType* const __restrict__ p_dx,
                               ScaleDataType* const __restrict__ p_scale_diff,
                               BiasDataType* const __restrict__ p_bias_diff)
    {
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
            reduce_scale_diff_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * 1, true>
            reduce_bias_diff_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            scale_diff_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            bias_diff_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            dy_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            dx_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            inv_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> scale_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / blkgroup_size;
        const index_t block_local_id  = block_global_id % blkgroup_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;

        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        // Step 1: do final reduction for scale_diff and bias_diff and output

        auto threadwise_scale_diff_load_m_k = ThreadwiseTensorSliceTransfer_v2<
            ScaleDataType, AccDataType, ScaleBiasDiffGridDesc_M_K,
            decltype(thread_buffer_desc_m_1), ThreadBufferLengths_M_1, Sequence<0, 1>, 1, 1, 1,
            true>(scale_bias_diff_grid_desc_m_k,
                  make_multi_index(blkgroup_id * M_BlockTileSize +
                                       thread_m_cluster_id * MThreadSliceSize,
                                   thread_k_cluster_id * 1));

        auto threadwise_bias_diff_load_m_k = ThreadwiseTensorSliceTransfer_v2<
            BiasDataType, AccDataType, ScaleBiasDiffGridDesc_M_K,
            decltype(thread_buffer_desc_m_1), ThreadBufferLengths_M_1, Sequence<0, 1>, 1, 1, 1,
            true>(scale_bias_diff_grid_desc_m_k,
                  make_multi_index(blkgroup_id * M_BlockTileSize +
                                       thread_m_cluster_id * MThreadSliceSize,
                                   thread_k_cluster_id * 1));

        auto threadwise_scale_diff_store_m = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType, ScaleDataType, decltype(thread_buffer_desc_m), ScaleBiasGridDesc_M,
            PassThroughOp, ThreadBufferLengths_M, Sequence<0>, 0, ScaleSrcDstVectorSize,
            InMemoryDataOperationEnum::Set, 1, true>(
            scale_grid_desc_m,
            make_multi_index(blkgroup_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize),
            PassThroughOp{});

        auto threadwise_bias_diff_store_m = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType, BiasDataType, decltype(thread_buffer_desc_m), ScaleBiasGridDesc_M,
            PassThroughOp, ThreadBufferLengths_M, Sequence<0>, 0, BiasDstVectorSize,
            InMemoryDataOperationEnum::Set, 1, true>(
            bias_grid_desc_m,
            make_multi_index(blkgroup_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize),
            PassThroughOp{});

        const auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_scale_diff, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());
        const auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_bias_diff, scale_bias_diff_grid_desc_m_k.GetElementSpaceSize());

        auto scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale_diff, scale_grid_desc_m.GetElementSpaceSize());
        auto bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias_diff, bias_grid_desc_m.GetElementSpaceSize());

        constexpr auto scale_bias_diff_thread_copy_step_m_k =
            make_multi_index(0, KThreadClusterSize * 1);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            scale_diff_thread_buf(I) = type_convert<AccDataType>(0.0f);
            bias_diff_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_scale_bias_diff_k_block_tile_iteration;
            ++reducedTiles)
        {
            threadwise_scale_diff_load_m_k.Run(scale_bias_diff_grid_desc_m_k,
                                               reduce_scale_diff_global_val_buf,
                                               thread_buffer_desc_m_1, make_tuple(I0, I0),
                                               reduce_scale_diff_thread_buf);
            threadwise_bias_diff_load_m_k.Run(scale_bias_diff_grid_desc_m_k,
                                              reduce_bias_diff_global_val_buf,
                                              thread_buffer_desc_m_1, make_tuple(I0, I0),
                                              reduce_bias_diff_thread_buf);

            ThreadwiseReduce::Reduce(reduce_scale_diff_thread_buf, scale_diff_thread_buf);
            ThreadwiseReduce::Reduce(reduce_bias_diff_thread_buf, bias_diff_thread_buf);

            threadwise_scale_diff_load_m_k.MoveSrcSliceWindow(
                scale_bias_diff_grid_desc_m_k, scale_bias_diff_thread_copy_step_m_k);
            threadwise_bias_diff_load_m_k.MoveSrcSliceWindow(
                scale_bias_diff_grid_desc_m_k, scale_bias_diff_thread_copy_step_m_k);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            BlockwiseReduce::Reduce(reduce_work_buf, scale_diff_thread_buf(I));
            block_sync_lds();
            BlockwiseReduce::Reduce(reduce_work_buf, bias_diff_thread_buf(I));
        });

        threadwise_scale_diff_store_m.Run(thread_buffer_desc_m, make_tuple(I0),
                                          scale_diff_thread_buf, scale_grid_desc_m,
                                          scale_diff_global_val_buf);
        threadwise_bias_diff_store_m.Run(thread_buffer_desc_m, make_tuple(I0),
                                         bias_diff_thread_buf, bias_grid_desc_m,
                                         bias_diff_global_val_buf);

        // Step 2: calculate dx = 1/N * invVar * scale * (N * dy - biasDiff - scaleDiff * (x - mean)
        // * invVar) and output

        const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<
            XDataType, AccDataType, XYGridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder, XDyDxVectorDim, XSrcVectorSize, 1,
            true>(x_grid_desc_m_k,
                  make_multi_index(blkgroup_id * M_BlockTileSize +
                                       thread_m_cluster_id * MThreadSliceSize,
                                   workSizePerBlock * block_local_id +
                                       thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<
            DyDataType, AccDataType, XYGridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder, XDyDxVectorDim, DySrcVectorSize, 1,
            true>(dy_grid_desc_m_k,
                  make_multi_index(blkgroup_id * M_BlockTileSize +
                                       thread_m_cluster_id * MThreadSliceSize,
                                   workSizePerBlock * block_local_id +
                                       thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dx_store = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType, DxDataType, decltype(thread_buffer_desc_m_k), XYGridDesc_M_K,
            PassThroughOp, ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder, XDyDxVectorDim,
            DxDstVectorSize, InMemoryDataOperationEnum::Set, 1, true>(
            dx_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             workSizePerBlock * block_local_id +
                                 thread_k_cluster_id * KThreadSliceSize),
            PassThroughOp{});

        auto threadwise_scale_load = ThreadwiseTensorSliceTransfer_v2<
            ScaleDataType, AccDataType, ScaleBiasGridDesc_M, decltype(thread_buffer_desc_m),
            ThreadBufferLengths_M, Sequence<0>, 0, ScaleSrcDstVectorSize, 1, true>(
            scale_grid_desc_m,
            make_multi_index(blkgroup_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize));

        auto threadwise_mean_var_load = ThreadwiseTensorSliceTransfer_v2<
            MeanVarDataType, AccDataType, MeanVarGridDesc_M, decltype(thread_buffer_desc_m),
            ThreadBufferLengths_M, Sequence<0>, 0, MeanVarSrcVectorSize, 1, true>(
            mean_var_grid_desc_m,
            make_multi_index(blkgroup_id * M_BlockTileSize +
                             thread_m_cluster_id * MThreadSliceSize));

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());
        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy, dy_grid_desc_m_k.GetElementSpaceSize());
        auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dx, dx_grid_desc_m_k.GetElementSpaceSize());

        const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale, scale_grid_desc_m.GetElementSpaceSize());
        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean, mean_var_grid_desc_m.GetElementSpaceSize());
        const auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize());

        threadwise_scale_load.Run(scale_grid_desc_m, scale_global_val_buf, thread_buffer_desc_m,
                                  make_tuple(I0), scale_thread_buf);
        threadwise_mean_var_load.Run(mean_var_grid_desc_m, mean_global_val_buf,
                                     thread_buffer_desc_m, make_tuple(I0), mean_thread_buf);
        threadwise_mean_var_load.Run(mean_var_grid_desc_m, inv_var_global_val_buf,
                                     thread_buffer_desc_m, make_tuple(I0), inv_var_thread_buf);

        constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);

        for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                  make_tuple(I0, I0), x_thread_buf);
            threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf, thread_buffer_desc_m_k,
                                   make_tuple(I0, I0), dy_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                AccDataType multiplier = type_convert<AccDataType>(1.0) /
                                         type_convert<AccDataType>(reduce_size) *
                                         inv_var_thread_buf[iM] * scale_thread_buf[iM];

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                         inv_var_thread_buf[iM];

                    AccDataType tmpVal = norm_x * scale_diff_thread_buf[iM];

                    dx_thread_buf(Number<offset>{}) =
                        multiplier *
                        (type_convert<AccDataType>(reduce_size) * dy_thread_buf[Number<offset>{}] -
                         bias_diff_thread_buf[iM] - tmpVal);
                });
            });

            threadwise_dx_store.Run(thread_buffer_desc_m_k, make_tuple(I0, I0), dx_thread_buf,
                                    dx_grid_desc_m_k, dx_global_val_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
            threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, xy_thread_copy_step_m_k);
        }
    };
};

} // namespace ck
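The "Step 2" loop in the file above evaluates dx with the per-row constant factored out into the multiplier variable. Written out with the same symbols as in the earlier sketch (this is only an annotation of the code, not new material), the factorization is:

    dx_i \;=\; \underbrace{\tfrac{1}{N}\,\sigma_{\mathrm{inv}}\,\gamma}_{\texttt{multiplier}}
    \Bigl(N\,dy_i \;-\; \underbrace{d\beta}_{\texttt{bias\_diff}} \;-\;
    \underbrace{\hat{x}_i\,d\gamma}_{\texttt{tmpVal}\,=\,\texttt{norm\_x}\cdot\texttt{scale\_diff}}\Bigr)

Computing multiplier once per iM avoids repeating the division by reduce_size and the two multiplies for every element of the K slice.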
include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_first_half.hpp    new file (0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseMultiblockWelfordFirstHalf_,
          typename XDataType, typename MeanVarDataType,
          typename XGridDesc_M_K, typename MeanVarCountGridDesc_M_G,
          typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_welford_first_half(
    const XGridDesc_M_K x_grid_desc_m_k,
    const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
    const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
    index_t num_k_block_tile_iteration,
    const XDataType* const __restrict__ p_x,
    MeanVarDataType* const p_welford_mean,
    MeanVarDataType* const p_welford_variance,
    int32_t* const p_welford_count)
{
    GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k, mean_var_count_grid_desc_m_g,
                                             get_reduce_count_per_thread,
                                             num_k_block_tile_iteration, p_x, p_welford_mean,
                                             p_welford_variance, p_welford_count);
};

template <typename XDataType, typename AccDataType, typename MeanVarDataType,
          typename XGridDesc_M_K, typename MeanVarCountGridDesc_M_G,
          typename GetReduceCountPerThreadFunctor,
          index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize,
          index_t MThreadSliceSize, index_t KThreadSliceSize,
          index_t XSrcCountSrcVectorDim, index_t XSrcCountSrcVectorSize>
struct GridwiseMultiblockWelfordFirstHalf
{
    static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
                      (XSrcCountSrcVectorDim == 1 &&
                       KThreadSliceSize % XSrcCountSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder, false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
                               const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
                               const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
                               index_t num_k_block_tile_iteration,
                               const XDataType* const __restrict__ p_x,
                               MeanVarDataType* const p_welford_mean,
                               MeanVarDataType* const p_welford_variance,
                               int32_t* const p_welford_count)
    {
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / blkgroup_size;
        const index_t block_local_id  = block_global_id % blkgroup_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;

        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<
            XDataType, AccDataType, XGridDesc_M_K, decltype(thread_buffer_desc_m_k),
            ThreadBufferLengths_M_K, ThreadBufferDimAccessOrder, XSrcCountSrcVectorDim,
            XSrcCountSrcVectorSize, 1, true>(
            x_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_welford_mean_var_store = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType, MeanVarDataType, decltype(thread_buffer_desc_m_1),
            MeanVarCountGridDesc_M_G, PassThroughOp, ThreadBufferLengths_M_1, Sequence<0, 1>, 1, 1,
            InMemoryDataOperationEnum::Set, 1, true>(
            mean_var_count_grid_desc_m_g,
            make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             block_local_id),
            PassThroughOp{});

        auto threadwise_welford_count_store = ThreadwiseTensorSliceTransfer_v1r3<
            int32_t, int32_t, decltype(thread_buffer_desc_m_1), MeanVarCountGridDesc_M_G,
            PassThroughOp, ThreadBufferLengths_M_1, Sequence<0, 1>, 1, 1,
            InMemoryDataOperationEnum::Set, 1, true>(
            mean_var_count_grid_desc_m_g,
            make_multi_index(blkgroup_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             block_local_id),
            PassThroughOp{});

        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
        auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
        auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ =
            get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            welford_var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k,
                                  make_tuple(I0, I0), x_thread_buf);
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
            threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            welford_count_thread_buf(I) = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(welford_mean_thread_buf(I), welford_var_thread_buf(I),
                                  welford_count_thread_buf(I));
        });

        if(thread_k_cluster_id == 0)
        {
            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, make_tuple(I0, I0),
                                                  welford_mean_thread_buf,
                                                  mean_var_count_grid_desc_m_g,
                                                  welford_mean_global_val_buf);
            threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, make_tuple(I0, I0),
                                                  welford_var_thread_buf,
                                                  mean_var_count_grid_desc_m_g,
                                                  welford_var_global_val_buf);
            threadwise_welford_count_store.Run(thread_buffer_desc_m_1, make_tuple(I0, I0),
                                               welford_count_thread_buf,
                                               mean_var_count_grid_desc_m_g,
                                               welford_count_global_val_buf);
        };
    }
};

} // namespace ck
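The first-half kernel above writes one partial (mean, variance, count) triple per block group of the reduction dimension into the M x G descriptor; the second-half kernel in the next file merges those partials before the gradient reduction. The merge is the usual pairwise combination of Welford/Chan partial statistics. The sketch below illustrates that rule under the population-variance convention; ThreadwiseWelford_2 and BlockwiseWelford in the actual headers are outside this diff, so the names and layout here are illustrative only.

    #include <cstdint>

    // One partial result as produced by the first-half kernel (per M row, per block group).
    struct PartialStats
    {
        float mean  = 0.0f; // mean over 'cnt' elements
        float var   = 0.0f; // population variance over 'cnt' elements
        int32_t cnt = 0;    // number of elements covered
    };

    // Merge partial statistics 'b' into 'a' (Chan et al. parallel combination).
    inline void welford_merge(PartialStats& a, const PartialStats& b)
    {
        if(b.cnt == 0)
            return;

        const int32_t n   = a.cnt + b.cnt;
        const float delta = b.mean - a.mean;

        // Combine the sums of squared deviations, then convert back to a variance.
        const float m2 = a.var * a.cnt + b.var * b.cnt + delta * delta * a.cnt * b.cnt / n;

        a.mean += delta * b.cnt / n;
        a.var = m2 / n;
        a.cnt = n;
    }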
include/ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp    new file (0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseWelfordSecondHalfReduceFirstHalf_
,
typename
XDataType
,
typename
DyDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
XYGridDesc_M_K
,
typename
MeanVarGridDesc_M
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasDiffGridDesc_M_G
>
__global__
void
kernel_welford_second_half_reduce_first_half
(
const
XYGridDesc_M_K
x_grid_desc_m_k
,
const
XYGridDesc_M_K
dy_grid_desc_m_k
,
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
const
MeanVarCountGridDesc_M_K
mean_var_count_grid_desc_m_k
,
const
ScaleBiasDiffGridDesc_M_G
scale_bias_grid_desc_m_g
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
bool
haveSavedMeanInvVar
,
const
MeanVarDataType
*
const
__restrict__
p_savedMean
,
const
MeanVarDataType
*
const
__restrict__
p_savedInvVar
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
const
int32_t
*
const
__restrict__
p_in_welford_count
,
MeanVarDataType
*
const
__restrict__
p_out_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_out_welford_inv_variance
,
const
XDataType
*
const
__restrict__
p_x
,
const
DyDataType
*
const
__restrict__
p_dy
,
ScaleDataType
*
const
__restrict__
p_reduce_scale_diff
,
BiasDataType
*
const
__restrict__
p_reduce_bias_diff
)
{
GridwiseWelfordSecondHalfReduceFirstHalf_
::
Run
(
x_grid_desc_m_k
,
dy_grid_desc_m_k
,
mean_var_grid_desc_m
,
mean_var_count_grid_desc_m_k
,
scale_bias_grid_desc_m_g
,
blkgroup_size
,
num_xy_k_block_tile_iteration
,
num_mean_var_count_k_block_tile_iteration
,
epsilon
,
haveSavedMeanInvVar
,
p_savedMean
,
p_savedInvVar
,
p_in_welford_mean
,
p_in_welford_variance
,
p_in_welford_count
,
p_out_welford_mean
,
p_out_welford_inv_variance
,
p_x
,
p_dy
,
p_reduce_scale_diff
,
p_reduce_bias_diff
);
};
template
<
typename
XDataType
,
typename
DyDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
XYGridDesc_M_K
,
typename
MeanVarGridDesc_M
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasDiffGridDesc_M_G
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XDyVectorDim
,
index_t
XSrcVectorSize
,
index_t
DySrcVectorSize
,
index_t
MeanVarSrcVectorSize
>
struct
GridwiseWelfordSecondHalfReduceFirstHalf
{
static_assert
((
XDyVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
&&
MThreadSliceSize
%
DySrcVectorSize
==
0
)
||
(
XDyVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
&&
KThreadSliceSize
%
DySrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
XDyVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceSrcDesc_M_1
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadwiseWelford
=
ThreadwiseWelford_2
<
AccDataType
,
ThreadReduceSrcDesc_M_1
,
ThreadReduceDstDesc_M
>
;
using
BlockwiseWelford
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
>
;
using
BlockwiseReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
ck
::
reduce
::
Add
,
false
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
ck
::
reduce
::
Add
,
false
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
XYGridDesc_M_K
&
x_grid_desc_m_k
,
const
XYGridDesc_M_K
&
dy_grid_desc_m_k
,
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
const
MeanVarCountGridDesc_M_K
&
mean_var_count_grid_desc_m_k
,
const
ScaleBiasDiffGridDesc_M_G
&
scale_bias_diff_grid_desc_m_g
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
bool
haveSavedMeanInvVar
,
const
MeanVarDataType
*
const
__restrict__
p_savedMean
,
const
MeanVarDataType
*
const
__restrict__
p_savedInvVar
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
const
int32_t
*
const
__restrict__
p_in_welford_count
,
MeanVarDataType
*
const
__restrict__
p_out_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_out_welford_inv_variance
,
const
XDataType
*
const
__restrict__
p_x
,
const
DyDataType
*
const
__restrict__
p_dy
,
ScaleDataType
*
const
__restrict__
p_reduce_scale_diff
,
BiasDataType
*
const
__restrict__
p_reduce_bias_diff
)
{
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
1
,
true
>
in_welford_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
1
,
true
>
in_welford_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
*
1
,
true
>
in_welford_count_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
welford_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
welford_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
welford_count_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>&
mean_thread_buf
=
welford_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>&
inv_var_thread_buf
=
welford_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
dy_thread_buf
;
// buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
tmp1_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
reduce_scale_diff_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
reduce_bias_diff_thread_buf
;
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
blkgroup_size
;
const
index_t
block_local_id
=
block_global_id
%
blkgroup_size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
// Step 1: load existing mean and inv-variance do final welford reduction on mean and
// variance
if
(
haveSavedMeanInvVar
)
{
const
auto
mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_savedMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
const
auto
inv_var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_savedInvVar
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
threadwise_mean_inv_var_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcVectorSize
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_mean_inv_var_load
.
Run
(
mean_var_grid_desc_m
,
mean_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
);
threadwise_mean_inv_var_load
.
Run
(
mean_var_grid_desc_m
,
inv_var_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
inv_var_thread_buf
);
}
        else
        {
            const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            auto threadwise_mean_var_load_m_k =
                ThreadwiseTensorSliceTransfer_v2<AccDataType,
                                                 AccDataType,
                                                 MeanVarCountGridDesc_M_K,
                                                 decltype(thread_buffer_desc_m_1),
                                                 ThreadBufferLengths_M_1,
                                                 Sequence<0, 1>,
                                                 1,
                                                 1,
                                                 1,
                                                 true>(
                    mean_var_count_grid_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * 1));

            auto threadwise_count_load_m_k =
                ThreadwiseTensorSliceTransfer_v2<int32_t,
                                                 int32_t,
                                                 MeanVarCountGridDesc_M_K,
                                                 decltype(thread_buffer_desc_m_1),
                                                 ThreadBufferLengths_M_1,
                                                 Sequence<0, 1>,
                                                 1,
                                                 1,
                                                 1,
                                                 true>(
                    mean_var_count_grid_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * 1));

            constexpr auto mean_var_count_thread_copy_step_m_k =
                make_multi_index(0, KThreadClusterSize * 1);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                welford_mean_thread_buf(I)  = type_convert<AccDataType>(0.0f);
                welford_var_thread_buf(I)   = type_convert<AccDataType>(0.0f);
                welford_count_thread_buf(I) = 0;
            });

            for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
                ++reducedTiles)
            {
                threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                                 welford_mean_global_val_buf,
                                                 thread_buffer_desc_m_1,
                                                 make_tuple(I0, I0),
                                                 in_welford_mean_thread_buf);

                threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                                 welford_var_global_val_buf,
                                                 thread_buffer_desc_m_1,
                                                 make_tuple(I0, I0),
                                                 in_welford_var_thread_buf);

                threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                              welford_count_global_val_buf,
                                              thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              in_welford_count_thread_buf);

                ThreadwiseWelford::Run(in_welford_mean_thread_buf,
                                       in_welford_var_thread_buf,
                                       in_welford_count_thread_buf,
                                       welford_mean_thread_buf,
                                       welford_var_thread_buf,
                                       welford_count_thread_buf);

                threadwise_mean_var_load_m_k.MoveSrcSliceWindow(
                    mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k);
                threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                             mean_var_count_thread_copy_step_m_k);
            }

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();

                BlockwiseWelford::Run(welford_mean_thread_buf(I),
                                      welford_var_thread_buf(I),
                                      welford_count_thread_buf(I));
            });

            // calculate inv-variance from variance, stored in place
            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                welford_var_thread_buf(I) =
                    type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon);
            });
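            // Note (added for clarity): after the loop above, welford_var_thread_buf actually
            // holds the inverse standard deviation, 1 / sqrt(variance + epsilon); the
            // inv_var_thread_buf reference declared earlier aliases it, so the normalization
            // below can compute x_hat = (x - mean) * inv_var without revisiting the variance.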
            if(block_local_id == 0 && thread_k_cluster_id == 0)
            {
                auto threadwise_mean_inv_var_store =
                    ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                       MeanVarDataType,
                                                       decltype(thread_buffer_desc_m),
                                                       MeanVarGridDesc_M,
                                                       PassThroughOp,
                                                       ThreadBufferLengths_M,
                                                       Sequence<0>,
                                                       0,
                                                       1,
                                                       InMemoryDataOperationEnum::Set,
                                                       1,
                                                       true>(
                        mean_var_grid_desc_m,
                        make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize),
                        PassThroughOp{});

                auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());

                auto inv_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());

                threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                                  make_tuple(I0),
                                                  mean_thread_buf,
                                                  mean_var_grid_desc_m,
                                                  mean_global_val_buf);

                threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                                  make_tuple(I0),
                                                  inv_var_thread_buf,
                                                  mean_var_grid_desc_m,
                                                  inv_var_global_val_buf);
            }
        }
        const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;

        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType,
                                             AccDataType,
                                             XYGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             XDyVectorDim,
                                             XSrcVectorSize,
                                             1,
                                             true>(
                x_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 workSizePerBlock * block_local_id +
                                     thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dy_load =
            ThreadwiseTensorSliceTransfer_v2<DyDataType,
                                             AccDataType,
                                             XYGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             XDyVectorDim,
                                             DySrcVectorSize,
                                             1,
                                             true>(
                dy_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 workSizePerBlock * block_local_id +
                                     thread_k_cluster_id * KThreadSliceSize));

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy, dy_grid_desc_m_k.GetElementSpaceSize());

        constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            reduce_scale_diff_thread_buf(I) = type_convert<AccDataType>(0);
            reduce_bias_diff_thread_buf(I)  = type_convert<AccDataType>(0);
        });
        // Step 2: do first-half reduction on dy and dy * (x-mean) * inv-variance
        for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_dy_load.Run(dy_grid_desc_m_k,
                                   dy_global_val_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   dy_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                         inv_var_thread_buf[iM];

                    tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
                });
            });

            ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_scale_diff_thread_buf);
            ThreadwiseReduce::Reduce(dy_thread_buf, reduce_bias_diff_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            BlockwiseReduce::Reduce(reduce_work_buf, reduce_scale_diff_thread_buf(I));
            block_sync_lds();
            BlockwiseReduce::Reduce(reduce_work_buf, reduce_bias_diff_thread_buf(I));
        });
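        // Note (added for clarity): after the blockwise reduction, the threads with
        // thread_k_cluster_id == 0 hold this workgroup's partial sums of
        //   d_bias  ~ sum(dy)
        //   d_scale ~ sum(dy * (x - mean) * inv_var)
        // over its slice of the reduce dimension. The stores below write them out per
        // block_local_id so a follow-up kernel can finish the reduction across block groups.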
        auto threadwise_scale_diff_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               ScaleDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               ScaleBiasDiffGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                scale_bias_diff_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});

        auto threadwise_bias_diff_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               BiasDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               ScaleBiasDiffGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                scale_bias_diff_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});

        auto reduce_scale_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_scale_diff, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());

        auto reduce_bias_diff_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_bias_diff, scale_bias_diff_grid_desc_m_g.GetElementSpaceSize());

        if(thread_k_cluster_id == 0)
        {
            threadwise_scale_diff_store.Run(thread_buffer_desc_m_1,
                                            make_tuple(I0, I0),
                                            reduce_scale_diff_thread_buf,
                                            scale_bias_diff_grid_desc_m_g,
                                            reduce_scale_diff_global_val_buf);

            threadwise_bias_diff_store.Run(thread_buffer_desc_m_1,
                                           make_tuple(I0, I0),
                                           reduce_bias_diff_thread_buf,
                                           scale_bias_diff_grid_desc_m_g,
                                           reduce_bias_diff_global_val_buf);
        }
    }
};

} // namespace ck
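For reference, the per-channel quantities accumulated above have a simple host-side definition: d_bias[c] is the sum of dy over the reduced dimension, and d_scale[c] is the sum of dy * (x - mean) * inv-variance. The sketch below is an illustrative reference only, not part of this commit; the function name ref_batchnorm_bwd_scale_bias, the channel-major flattened layout, and the biased (divide-by-N) variance are assumptions made for the example.

// Minimal host-side reference for the per-channel reductions performed above
// (illustrative sketch only; names and layout are assumptions for this example).
#include <cmath>
#include <cstddef>
#include <vector>

void ref_batchnorm_bwd_scale_bias(const std::vector<float>& x,    // [C][N] flattened
                                  const std::vector<float>& dy,   // [C][N] flattened
                                  const std::vector<float>& mean, // [C]
                                  std::size_t C,
                                  std::size_t N,
                                  float epsilon,
                                  std::vector<float>& d_scale, // [C] output
                                  std::vector<float>& d_bias)  // [C] output
{
    d_scale.assign(C, 0.0f);
    d_bias.assign(C, 0.0f);

    for(std::size_t c = 0; c < C; ++c)
    {
        // variance of channel c (two-pass here, for clarity only)
        float var = 0.0f;
        for(std::size_t i = 0; i < N; ++i)
        {
            const float d = x[c * N + i] - mean[c];
            var += d * d;
        }
        var /= static_cast<float>(N);

        const float inv_var = 1.0f / std::sqrt(var + epsilon);

        for(std::size_t i = 0; i < N; ++i)
        {
            const float norm_x = (x[c * N + i] - mean[c]) * inv_var;
            d_bias[c] += dy[c * N + i];
            d_scale[c] += dy[c * N + i] * norm_x; // dy * (x - mean) * inv-variance
        }
    }
}

The multiblock kernel above computes exactly these sums, except that each block in a block group produces a partial sum for its K slice; the companion "reduce_second_half_batchnorm_backward_final" kernel in this commit then presumably folds those partials together and forms dx.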
include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp
@@ -75,4 +75,55 @@ struct ThreadwiseWelford
    int max_count_;
};
template <typename T,
          typename SrcMeanVarCountThreadDesc_M_K,
          typename DstMeanVarThreadDesc_M>
struct ThreadwiseWelford_2
{
    static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
    static constexpr auto dst_thread_desc_m   = DstMeanVarThreadDesc_M{};

    static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
    static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
    static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});

    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");

    __device__ static void
    Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
    {
        int count = count_a + count_b;
        T count_b_over_count =
            count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
        T delta = mean_b - mean_a;

        mean_a += delta * count_b_over_count;
        var_a += var_b + delta * delta * count_a * count_b_over_count;
        count_a = count;
    }
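    // Note (added for clarity): Merge combines two Welford partial results using the standard
    // parallel update (Chan et al.). With n = count_a + count_b and delta = mean_b - mean_a:
    //   mean = mean_a + delta * count_b / n
    //   M2   = M2_a + M2_b + delta^2 * count_a * count_b / n
    // Here var_a / var_b hold the running sums of squared deviations (M2), not the finalized
    // variance; the division by the total count (and the epsilon / rsqrt step) is left to the
    // blockwise and gridwise callers.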
    template <typename SrcMeanBufferType,
              typename SrcVarBufferType,
              typename SrcCountBufferType,
              typename DstMeanBufferType,
              typename DstVarBufferType,
              typename DstCountBufferType>
    __device__ static void Run(const SrcMeanBufferType& src_mean_buf,
                               const SrcVarBufferType& src_var_buf,
                               const SrcCountBufferType& src_count_buf,
                               DstMeanBufferType& dst_mean_buf,
                               DstVarBufferType& dst_var_buf,
                               DstCountBufferType& dst_count_buf)
    {
        static_for<0, src_length_m, 1>{}([&](auto iM) {
            static_for<0, src_length_k, 1>{}([&](auto iK) {
                constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                Merge(dst_mean_buf(iM),
                      dst_var_buf(iM),
                      dst_count_buf(iM),
                      src_mean_buf[Number<src_offset>{}],
                      src_var_buf[Number<src_offset>{}],
                      src_count_buf[Number<src_offset>{}]);
            });
        });
    }
};

} // namespace ck
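As a quick sanity check on the merge arithmetic, the same update can be exercised on plain floats on the host: merging the Welford states of two disjoint halves of a sample should reproduce the mean and the sum of squared deviations (M2) of the whole sample. The snippet below is a standalone illustration; its names (welford_accumulate, welford_merge) are invented for the example and do not appear in the commit.

// Standalone host-side illustration of the Welford merge used by ThreadwiseWelford_2::Merge
// (names here are invented for the example).
#include <cstdio>
#include <vector>

static void welford_accumulate(const std::vector<float>& v, float& mean, float& m2, int& count)
{
    for(float x : v)
    {
        ++count;
        const float delta = x - mean;
        mean += delta / static_cast<float>(count);
        m2 += delta * (x - mean);
    }
}

static void welford_merge(float& mean_a, float& m2_a, int& count_a,
                          float mean_b, float m2_b, int count_b)
{
    const int count      = count_a + count_b;
    const float b_over_n = count == 0 ? 0.0f : static_cast<float>(count_b) / count;
    const float delta    = mean_b - mean_a;

    mean_a += delta * b_over_n;
    m2_a += m2_b + delta * delta * count_a * b_over_n;
    count_a = count;
}

int main()
{
    const std::vector<float> lo{1.0f, 2.0f, 3.0f};
    const std::vector<float> hi{10.0f, 20.0f};

    float mean_a = 0.0f, m2_a = 0.0f;
    int count_a = 0;
    welford_accumulate(lo, mean_a, m2_a, count_a);

    float mean_b = 0.0f, m2_b = 0.0f;
    int count_b = 0;
    welford_accumulate(hi, mean_b, m2_b, count_b);

    welford_merge(mean_a, m2_a, count_a, mean_b, m2_b, count_b);

    // Expect the same mean/M2 as accumulating all five values in one pass.
    std::printf("merged: count=%d mean=%f variance=%f\n",
                count_a, mean_a, m2_a / static_cast<float>(count_a));
    return 0;
}

Running this prints count=5, mean=7.2 and variance of roughly 50.96, which matches a direct computation over all five values.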