Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4e911f3e
Commit
4e911f3e
authored
Jul 07, 2023
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
fa9da1a4
4939ee59
Changes
139
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1701 additions
and
115 deletions
+1701
-115
example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp
...antization/conv2d_fwd_dl_perchannel_quantization_int8.cpp
+1
-1
example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp
...quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp
+1
-1
example/41_grouped_conv_conv_fwd/CMakeLists.txt
example/41_grouped_conv_conv_fwd/CMakeLists.txt
+2
-6
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp
...mm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp
+1
-1
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp
...mm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp
+1
-1
include/ck/ck.hpp
include/ck/ck.hpp
+23
-0
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
+13
-11
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
...sor_operation/gpu/block/reduction_functions_blockwise.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+0
-2
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
...eration/gpu/device/impl/device_batchnorm_forward_impl.hpp
+185
-81
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
...pu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
+714
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp
...gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+6
-9
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
...device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
...mpl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
+0
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+48
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
...norm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
+704
-0
No files found.
example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp
View file @
4e911f3e
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using
InDataType
=
int8_t
;
using
WeiDataType
=
int8_t
;
...
...
example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp
View file @
4e911f3e
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using
InDataType
=
int8_t
;
using
WeiDataType
=
int8_t
;
...
...
example/41_grouped_conv_conv_fwd/CMakeLists.txt
View file @
4e911f3e
...
...
@@ -13,10 +13,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif
()
endforeach
()
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list2 AND target EQUAL 0
)
if
(
NOT GPU_TARGETS MATCHES
"gfx94"
AND NOT GPU_TARGETS MATCHES
"gfx1"
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp
)
set
(
target 1
)
endif
()
endforeach
()
endif
()
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp
View file @
4e911f3e
...
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp
View file @
4e911f3e
...
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
include/ck/ck.hpp
View file @
4e911f3e
...
...
@@ -27,6 +27,21 @@
#define CK_WAVELET_MIN_BLOCK_PER_CU 2
#endif
// kernel attribute: amdgpu_waves_per_eu()
#ifdef CK_USE_WAVES_PER_EU
// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute
#ifndef CK_MIN_WAVES_PER_EU
#define CK_MIN_WAVES_PER_EU 0
#endif
#ifndef CK_MAX_WAVES_PER_EU
#define CK_MAX_WAVES_PER_EU 0
#endif
#else
#define CK_USE_WAVES_PER_EU 0
#endif
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
...
...
@@ -148,6 +163,10 @@
#define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1
// experimental feature: add instances using pipeline v2
#define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1
// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy)
#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0
#endif
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
...
...
@@ -173,6 +192,10 @@
// workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
#define CK_WORKAROUND_SWDEV_3318619 0
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
...
...
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
View file @
4e911f3e
...
...
@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/
reduction_common
.hpp"
#include "ck/utility/
get_shift
.hpp"
namespace
ck
{
...
...
@@ -35,10 +35,11 @@ struct BlockwiseWelford
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
template
<
typename
CountDataType
>
__device__
static
inline
void
Merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
Merge
(
T
&
mean_a
,
T
&
var_a
,
CountDataType
&
count_a
,
T
mean_b
,
T
var_b
,
CountDataType
count_b
)
{
int
count
=
count_a
+
count_b
;
CountDataType
count
=
count_a
+
count_b
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
type_convert
<
T
>
(
count_b
)
/
count
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
...
...
@@ -46,11 +47,12 @@ struct BlockwiseWelford
count_a
=
count
;
}
__device__
static
void
Run
(
T
&
mean_value
,
T
&
var_value
,
int
&
count
)
template
<
typename
CountDataType
>
__device__
static
void
Run
(
T
&
mean_value
,
T
&
var_value
,
CountDataType
&
count
)
{
__shared__
T
mean_block_buf
[
BlockSize
];
__shared__
T
var_block_buf
[
BlockSize
];
__shared__
int
count_block_buf
[
BlockSize
];
__shared__
CountDataType
count_block_buf
[
BlockSize
];
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
...
...
@@ -76,13 +78,13 @@ struct BlockwiseWelford
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
make_tuple
(
0
,
indOffset
));
T
mean1
=
mean_block_buf
[
offset1
];
T
var1
=
var_block_buf
[
offset1
];
int
count1
=
count_block_buf
[
offset1
];
T
mean1
=
mean_block_buf
[
offset1
];
T
var1
=
var_block_buf
[
offset1
];
CountDataType
count1
=
count_block_buf
[
offset1
];
T
mean2
=
mean_block_buf
[
offset2
];
T
var2
=
var_block_buf
[
offset2
];
int
count2
=
count_block_buf
[
offset2
];
T
mean2
=
mean_block_buf
[
offset2
];
T
var2
=
var_block_buf
[
offset2
];
CountDataType
count2
=
count_block_buf
[
offset2
];
Merge
(
mean1
,
var1
,
count1
,
mean2
,
var2
,
count2
);
...
...
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
View file @
4e911f3e
...
...
@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/
reduction_common
.hpp"
#include "ck/utility/
get_shift
.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace
ck
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
4e911f3e
...
...
@@ -786,12 +786,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
if
(
arg
.
d0s_nl_ns_lengths_strides_
[
i
][
1
]
==
1
&&
arg
.
d0s_nl_ns_lengths_strides_
[
i
][
0
]
%
D0sTransferSrcScalarPerVector
!=
0
)
{
std
::
cout
<<
"first"
<<
std
::
endl
;
return
false
;
}
if
(
arg
.
d0s_nl_ns_lengths_strides_
[
i
][
1
]
!=
1
&&
D0sTransferSrcScalarPerVector
!=
1
)
{
std
::
cout
<<
"second"
<<
std
::
endl
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
View file @
4e911f3e
...
...
@@ -10,12 +10,14 @@
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final
_obsolete
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static
auto
MakeMeanVarCountOutputMG2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
grid_desc_m_g
=
make_
naive_tensor_descriptor_packed
(
make_tuple
(
invariantLength
,
blkGroupSize
));
const
auto
grid_desc_m_g
=
make_naive_tensor_descriptor
(
make_
tuple
(
invariantLength
,
blkGroupSize
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
...
...
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static
auto
MakeMeanVarCountInputMK2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
reduceLength
=
blkGroupSize
;
const
auto
grid_desc_m_k
=
make_
naive_tensor_descriptor_packed
(
make_tuple
(
invariantLength
,
reduceLength
));
const
auto
reduceLength
=
blkGroupSize
;
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_
tuple
(
invariantLength
,
reduceLength
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
...
...
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
int
testBlkGroupSize
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
// we want the blkGroupSize be not more than 1
28
if
(
testBlkGroupSize
<=
1
28
)
// we want the blkGroupSize be not more than 1
6
if
(
testBlkGroupSize
<=
1
6
)
break
;
iterations
++
;
...
...
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
void
*
workspace_mean_
;
void
*
workspace_variance_
;
void
*
workspace_count_
;
void
*
control_
;
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
...
...
@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
)
+
64
;
// workspace for barrier objects, each barrier object consists of two integers
// TODO: allocate barrier object memory globally to reuse it by other operators
workspace_size
+=
(
pArg_
->
invariant_length_
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
sizeof
(
int
)
*
2
;
}
return
(
workspace_size
);
...
...
@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
{
// setup buffer used for intermediate welford mean
pArg_
->
workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
...
...
@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// setup buffer used for intermediate welfor count
pArg_
->
workspace_count_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_variance_
)
+
variance_space_sz
;
index_t
count_space_sz
=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
);
count_space_sz
=
math
::
integer_least_multiple
(
count_space_sz
,
64
);
pArg_
->
control_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_count_
)
+
count_space_sz
;
index_t
control_space_sz
=
(
pArg_
->
invariant_length_
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
sizeof
(
int
)
*
2
;
hip_check_error
(
hipMemset
(
pArg_
->
control_
,
0
,
control_space_sz
));
};
};
...
...
@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
GridwiseMultiblockBatchNormForward_
=
GridwiseMultiblockBatchNormForward
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
using
GridwiseMultiblockWelfordFirstHalf_
=
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
AccDataType
,
...
...
@@ -441,78 +487,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
index_t
numMeanVarCountBlockTileIteration
=
(
arg
.
blkGroupSize_
+
KThreadClusterSize
-
1
)
/
KThreadClusterSize
;
const
auto
kern_multiblock_welford_first_half
=
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
XDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
>
;
const
auto
kern_welford_second_half_batchnorm_forward_final
=
kernel_welford_second_half_batchnorm_forward_final
<
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_batchnorm_forward_final
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_k
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
numMeanVarCountBlockTileIteration
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
// It is found that:
// 1) gfx1030 does not support the GLC enabled vector load/store, so using the
// two-kernel method for gfx1030
// 2) Profiler on gfx908 could hang even though it works when running examples
// 3) Single-kernel method works on gfx1100, but the performance it not better
// than two-kernel method (due to more warps participating the barrier)
if
(
ck
::
get_device_name
()
==
"gfx90a"
)
{
const
auto
kern_multiblock_batchnorm_fwd_
=
kernel_multiblock_batchnorm_forward
<
GridwiseMultiblockBatchNormForward_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_batchnorm_fwd_
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
// for writing to mean/variance/count
// workspace by multiple workgroups
mean_var_count_grid_desc_m_k
,
// for reading from mean/variance/count
// workspace by each workgroup
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
static_cast
<
int
*>
(
arg
.
control_
),
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
// true or false
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
// true or false
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
}
else
{
const
auto
kern_multiblock_welford_first_half
=
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
XDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
>
;
const
auto
kern_welford_second_half_batchnorm_forward_final
=
kernel_welford_second_half_batchnorm_forward_final
<
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_batchnorm_forward_final
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_k
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
};
}
else
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
0 → 100644
View file @
4e911f3e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
bool
UseMultiblockInK
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcYDstVectorDim
,
index_t
XSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
DeviceBatchNormFwdImpl
:
public
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
"Invalid thread cluster size assignments!"
);
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeXY2dDescriptor
(
const
std
::
array
<
index_t
,
Rank
>&
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>&
xyStrides
,
int
blkGroupSize
,
int
numBlockTileIteration
)
{
const
auto
tupleXYLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
I
];
},
Number
<
Rank
>
{});
const
auto
tupleXYStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
xyStrides
[
I
];
},
Number
<
Rank
>
{});
const
auto
raw_grid_desc
=
make_naive_tensor_descriptor
(
tupleXYLengths
,
tupleXYStrides
);
const
auto
grid_desc_m_k
=
[
&
]()
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
NumInvariantDim
+
I
];
},
Number
<
NumBatchNormReduceDim
>
{});
const
auto
invariantDimLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
return
transform_tensor_descriptor
(
raw_grid_desc
,
make_tuple
(
make_merge_transform
(
invariantDimLengths
),
make_merge_transform
(
reduceDimLengths
)),
make_tuple
(
InvariantDims
{},
ReduceDims
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
const
auto
invariantLength
=
grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
int
workSizePerBlock
=
K_BlockTileSize
*
numBlockTileIteration
;
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
kPad
=
workSizePerBlock
*
blkGroupSize
-
reduceLength
;
auto
grid_desc_m_k_padded
=
transform_tensor_descriptor
(
grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_right_pad_transform
(
reduceLength
,
kPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_k_padded
);
};
static
auto
MakeMeanVarCountOutputMG2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
grid_desc_m_g
=
make_naive_tensor_descriptor
(
make_tuple
(
invariantLength
,
blkGroupSize
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_g_padded
=
transform_tensor_descriptor
(
grid_desc_m_g
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_pass_through_transform
(
blkGroupSize
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_g_padded
);
};
static
auto
MakeMeanVarCountInputMK2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
reduceLength
=
blkGroupSize
;
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
invariantLength
,
reduceLength
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
kPad
=
math
::
integer_least_multiple
(
reduceLength
,
KThreadClusterSize
)
-
reduceLength
;
auto
grid_desc_m_k_padded
=
transform_tensor_descriptor
(
grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_right_pad_transform
(
reduceLength
,
kPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_k_padded
);
};
static
auto
MakeScaleBiasMeanVar1dDescriptor
(
const
std
::
array
<
index_t
,
NumInvariantDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>&
strides
)
{
const
auto
tupleLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
const
auto
tupleStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
strides
[
I
];
},
Number
<
NumInvariantDim
>
{});
auto
raw_grid_desc
=
make_naive_tensor_descriptor
(
tupleLengths
,
tupleStrides
);
auto
grid_desc_m
=
transform_tensor_descriptor
(
raw_grid_desc
,
make_tuple
(
make_merge_transform
(
tupleLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_padded
=
transform_tensor_descriptor
(
grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
(
grid_desc_m_padded
);
};
using
XYGridDesc_M_K
=
decltype
(
MakeXY2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
ScaleBiasMeanVarGridDesc_M
=
decltype
(
MakeScaleBiasMeanVar1dDescriptor
({
1
},
{
1
}));
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
ScaleDataType
*
p_scale
,
const
BiasDataType
*
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
double
epsilon
,
YDataType
*
p_y
,
MeanVarDataType
*
resultSaveMean
,
MeanVarDataType
*
resultSaveInvVariance
,
double
averageFactor
,
MeanVarDataType
*
resultRunningMean
,
MeanVarDataType
*
resultRunningVariance
)
:
bnScaleBiasMeanVarLengths_
(
bnScaleBiasMeanVarLengths
),
bnScaleStrides_
(
bnScaleStrides
),
bnBiasStrides_
(
bnBiasStrides
),
bnMeanVarStrides_
(
bnMeanVarStrides
),
p_x_
(
p_x
),
p_scale_
(
p_scale
),
p_bias_
(
p_bias
),
y_elementwise_op_
(
y_elementwise_op
),
p_y_
(
p_y
),
resultSaveMean_
(
resultSaveMean
),
resultSaveInvVariance_
(
resultSaveInvVariance
),
resultRunningMean_
(
resultRunningMean
),
resultRunningVariance_
(
resultRunningVariance
)
{
xyLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xStrides
,
reduceDims
);
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
yStrides
,
reduceDims
);
std
::
tie
(
invariant_length_
,
reduce_length_
)
=
get_2d_lengths
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths_
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
averageFactor_
=
type_convert
<
AccDataType
>
(
averageFactor
);
updateMovingAverage_
=
(
resultRunningMean
!=
nullptr
&&
resultRunningVariance
!=
nullptr
);
saveMeanInvVariance_
=
(
resultSaveMean
!=
nullptr
&&
resultSaveInvVariance_
!=
nullptr
);
if
(
UseMultiblockInK
)
{
int
iterations
=
1
;
while
(
true
)
{
int
testBlkGroupSize
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
// we want the blkGroupSize be not more than 16
if
(
testBlkGroupSize
<=
16
)
break
;
iterations
++
;
};
blkGroupSize_
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
numBlockTileIteration_
=
iterations
;
}
else
{
blkGroupSize_
=
1
;
numBlockTileIteration_
=
(
reduce_length_
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
};
gridSize_
=
(
invariant_length_
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
blkGroupSize_
;
x_grid_desc_m_k_
=
MakeXY2dDescriptor
(
xyLengths_
,
xStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeXY2dDescriptor
(
xyLengths_
,
yStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
scale_grid_desc_m_
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnScaleStrides_
);
bias_grid_desc_m_
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnBiasStrides_
);
mean_var_grid_desc_m_
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnMeanVarStrides_
);
}
AccDataType
epsilon_
;
AccDataType
averageFactor_
;
bool
updateMovingAverage_
;
bool
saveMeanInvVariance_
;
std
::
array
<
index_t
,
Rank
>
xyLengths_
;
std
::
array
<
index_t
,
Rank
>
xStrides_
;
std
::
array
<
index_t
,
Rank
>
yStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides_
;
const
XDataType
*
p_x_
;
const
ScaleDataType
*
p_scale_
;
const
BiasDataType
*
p_bias_
;
const
YElementwiseOp
y_elementwise_op_
;
YDataType
*
p_y_
;
MeanVarDataType
*
resultSaveMean_
;
MeanVarDataType
*
resultSaveInvVariance_
;
MeanVarDataType
*
resultRunningMean_
;
MeanVarDataType
*
resultRunningVariance_
;
long_index_t
invariant_length_
;
long_index_t
reduce_length_
;
int
blkGroupSize_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
XYGridDesc_M_K
x_grid_desc_m_k_
;
XYGridDesc_M_K
y_grid_desc_m_k_
;
ScaleBiasMeanVarGridDesc_M
scale_grid_desc_m_
;
ScaleBiasMeanVarGridDesc_M
bias_grid_desc_m_
;
ScaleBiasMeanVarGridDesc_M
mean_var_grid_desc_m_
;
void
*
workspace_mean_
;
void
*
workspace_variance_
;
void
*
workspace_count_
;
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
size_t
workspace_size
=
0
;
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
{
// workspace for welford intermediate mean
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate variance
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
)
+
64
;
}
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
pArg_
->
p_workspace_
=
p_workspace
;
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
{
// setup buffer used for intermediate welford mean
pArg_
->
workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
index_t
mean_space_sz
=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
// setup buffer used for intermediate welford varirance
pArg_
->
workspace_variance_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_mean_
)
+
mean_space_sz
;
index_t
variance_space_sz
=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
// setup buffer used for intermediate welfor count
pArg_
->
workspace_count_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_variance_
)
+
variance_space_sz
;
};
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
avg_time
=
0
;
if
(
UseMultiblockInK
&&
arg
.
blkGroupSize_
>
1
)
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForMultiblockWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
arg
.
reduce_length_
);
const
auto
mean_var_count_grid_desc_m_g
=
DeviceBatchNormFwdImpl
::
MakeMeanVarCountOutputMG2dDescriptor
(
arg
.
invariant_length_
,
arg
.
blkGroupSize_
);
const
auto
mean_var_count_grid_desc_m_k
=
DeviceBatchNormFwdImpl
::
MakeMeanVarCountInputMK2dDescriptor
(
arg
.
invariant_length_
,
arg
.
blkGroupSize_
);
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
GridwiseMultiblockWelfordFirstHalf_
=
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
AccDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
>
;
using
GridwiseWelfordSecondHalfBatchNormForwardFinal_
=
GridwiseWelfordSecondHalfBatchNormForwardFinal
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
const
auto
kern_multiblock_welford_first_half
=
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
XDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
>
;
const
auto
kern_welford_second_half_batchnorm_forward_final
=
kernel_welford_second_half_batchnorm_forward_final
<
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_batchnorm_forward_final
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_k
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
}
else
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForBlockwiseWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
numBlockTileIteration_
,
arg
.
reduce_length_
);
using
GridwiseBatchNormForwardWithBlockwiseWelford_
=
GridwiseBatchNormForwardWithBlockwiseWelford
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
const
auto
kern_batchnorm_fwd
=
kernel_batchnorm_forward_with_blockwise_welford
<
GridwiseBatchNormForwardWithBlockwiseWelford_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_batchnorm_fwd
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
// true or false
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
// true or false
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
};
return
(
avg_time
);
};
float
Run
(
const
BaseArgument
*
pArg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
pArg
),
stream_config
);
};
};
bool
IsSupportedArgument
(
const
BaseArgument
*
pArg
)
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
if
constexpr
(
XSrcYDstVectorDim
==
0
)
{
if
(
pArg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
||
pArg_
->
yStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
YDstVectorSize
!=
0
)
return
false
;
}
else
{
if
(
pArg_
->
xStrides_
[
Rank
-
1
]
!=
1
||
pArg_
->
yStrides_
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
Rank
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
Rank
-
1
]
%
YDstVectorSize
!=
0
)
return
false
;
};
if
(
pArg_
->
bnScaleStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
ScaleSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnBiasStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
BiasSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
ScaleSrcVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
BiasSrcVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
bnMeanVarStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
MeanVarSrcDstVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
MeanVarSrcDstVectorSize
!=
0
)
return
false
;
bool
is_valid
=
true
;
static_for
<
0
,
NumInvariantDim
,
1
>
{}([
&
](
auto
I
)
{
if
(
pArg_
->
xyLengths_
[
I
]
!=
pArg_
->
bnScaleBiasMeanVarLengths_
[
I
])
is_valid
=
false
;
});
if
(
!
is_valid
)
return
false
;
return
true
;
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
p_scale
,
const
void
*
p_bias
,
double
epsilon
,
const
YElementwiseOp
y_elementwise_op
,
void
*
p_y
,
void
*
resultSaveMean
,
void
*
resultSaveInvVariance
,
double
averageFactor
,
void
*
resultRunningMean
,
void
*
resultRunningVariance
)
override
{
return
std
::
make_unique
<
Argument
>
(
xyLengths
,
xStrides
,
yStrides
,
reduceDims
,
bnScaleBiasMeanVarLengths
,
bnScaleStrides
,
bnBiasStrides
,
bnMeanVarStrides
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
ScaleDataType
*>
(
p_scale
),
static_cast
<
const
BiasDataType
*>
(
p_bias
),
y_elementwise_op
,
epsilon
,
static_cast
<
YDataType
*>
(
p_y
),
static_cast
<
MeanVarDataType
*>
(
resultSaveMean
),
static_cast
<
MeanVarDataType
*>
(
resultSaveInvVariance
),
averageFactor
,
static_cast
<
MeanVarDataType
*>
(
resultRunningMean
),
static_cast
<
MeanVarDataType
*>
(
resultRunningVariance
));
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatchNormFwdImpl<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"XSrcYDstVectorDim_"
<<
XSrcYDstVectorDim
<<
","
;
str
<<
"VectorSize_X"
<<
XSrcVectorSize
<<
"_scale_"
<<
ScaleSrcVectorSize
<<
"_bias_"
<<
BiasSrcVectorSize
<<
"_mean_var_"
<<
MeanVarSrcDstVectorSize
<<
"_Y"
<<
YDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
4e911f3e
...
...
@@ -76,7 +76,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
static
constexpr
index_t
NumGemmKPrefetchStage
=
1
;
static
constexpr
LoopScheduler
LoopSched
=
make_default_loop_scheduler
();
static
constexpr
PipelineVersion
PipelineVer
=
PipelineVersion
::
v
2
;
static
constexpr
PipelineVersion
PipelineVer
=
PipelineVersion
::
v
1
;
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_gemm_xdl_waveletmodel_cshuffle.hpp
View file @
4e911f3e
File moved
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
View file @
4e911f3e
...
...
@@ -459,7 +459,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_k_wos_lengths
[
0
]},
num_gemm_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
...
...
@@ -508,9 +507,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
// number of GEMM
num_gemm_
=
YTilde
*
XTilde
;
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
YTilde
;
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
...
...
@@ -626,7 +622,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
void
Print
()
const
{
for
(
index
_t
i
=
0
;
i
<
num_gemm_
;
i
++
)
for
(
std
::
size
_t
i
=
0
;
i
<
a_grid_desc_ak0_m_ak1_container_
.
size
()
;
i
++
)
{
std
::
cout
<<
"a_grid_desc_ak0_m_ak1_container_"
<<
a_grid_desc_ak0_m_ak1_container_
[
i
]
<<
std
::
endl
;
...
...
@@ -654,7 +650,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// tensor descriptor for problem definition
index_t
num_group_
;
index_t
num_gemm_
;
std
::
vector
<
AGridDesc_M_K
>
a_grid_desc_m_k_container_
;
std
::
vector
<
BGridDesc_N_K
>
b_grid_desc_n_k_container_
;
std
::
vector
<
DsGridDesc_M_N
>
ds_grid_desc_m_n_container_
;
...
...
@@ -708,7 +703,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
float
ave_time
=
0
;
for
(
index
_t
i
=
0
;
i
<
arg
.
num_gemm_
;
i
++
)
for
(
std
::
size
_t
i
=
0
;
i
<
arg
.
a_grid_desc_ak0_m_ak1_container_
.
size
()
;
i
++
)
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_container_
[
i
],
arg
.
b_grid_desc_n_k_container_
[
i
],
...
...
@@ -807,7 +802,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// vector load for A matrix from global memory to LDS
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
)
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGK
>
)
{
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
ConvK
%
ABlockTransferSrcScalarPerVector
==
0
))
{
...
...
@@ -862,7 +858,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// vector store for E
if
constexpr
(
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
GNHWC
>
)
if
constexpr
(
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
is_same_v
<
ELayout
,
tensor_layout
::
convolution
::
NHWGC
>
)
{
// vector store C matrix into global memory
if
(
!
(
ConvC
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
4e911f3e
File moved
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
View file @
4e911f3e
File moved
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
4e911f3e
File moved
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
View file @
4e911f3e
File moved
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
4e911f3e
...
...
@@ -6,6 +6,7 @@
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -81,6 +82,36 @@ struct PassThrough
y
=
x
;
}
#endif
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
y
=
x
;
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
f8_t
>
(
float
&
y
,
const
f8_t
&
x
)
const
{
y
=
type_convert
<
float
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
float
>
(
f8_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
f8_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
f8_t
>
(
half_t
&
y
,
const
f8_t
&
x
)
const
{
y
=
type_convert
<
half_t
>
(
x
);
}
template
<
>
__host__
__device__
void
operator
()
<
f8_t
,
half_t
>
(
f8_t
&
y
,
const
half_t
&
x
)
const
{
y
=
type_convert
<
f8_t
>
(
x
);
}
};
struct
UnaryConvert
...
...
@@ -109,6 +140,23 @@ struct ConvertBF16RTN
}
};
struct
ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
// check Y datatype
static_assert
(
is_same
<
Y
,
f8_t
>::
value
,
"Data type is not supported by this operation!"
);
// check X datatype
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
half_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
f8_convert_sr
<
Y
>
(
x
);
}
};
struct
Scale
{
__host__
__device__
Scale
(
float
scale
)
:
scale_
(
scale
)
{}
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
0 → 100644
View file @
4e911f3e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_synchronization.hpp"
namespace
ck
{
template
<
typename
GridwiseMultiblockBatchNormForward_
,
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
GetReduceCountPerThreadFunctor
>
__global__
void
kernel_multiblock_batchnorm_forward
(
const
XYGridDesc_M_K
x_grid_desc_m_k
,
const
XYGridDesc_M_K
y_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
mean_var_count_grid_desc_m_g
,
const
MeanVarCountGridDesc_M_K
mean_var_count_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
bias_grid_desc_m
,
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
__restrict__
p_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
,
int32_t
*
const
__restrict__
p_control
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
GridwiseMultiblockBatchNormForward_
::
Run
(
x_grid_desc_m_k
,
y_grid_desc_m_k
,
mean_var_count_grid_desc_m_g
,
mean_var_count_grid_desc_m_k
,
scale_grid_desc_m
,
bias_grid_desc_m
,
mean_var_grid_desc_m
,
get_reduce_count_per_thread
,
num_k_block_tile_iteration
,
epsilon
,
p_x
,
p_welford_mean
,
p_welford_variance
,
p_welford_count
,
p_control
,
p_scale
,
p_bias
,
y_elementwise_op
,
p_y
,
updateMovingAverage
,
averageFactor
,
resultRunningMean
,
resultRunningVariance
,
saveMeanInvVariance
,
resultSaveMean
,
resultSaveInvVariance
);
};
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
GetReduceCountPerThreadFunctor
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcYDstVectorDim
,
index_t
XSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
GridwiseMultiblockBatchNormForward
{
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
YDstVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcYDstVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadReduceSrcDesc_M_1
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{})));
using
ThreadwiseWelford1
=
ThreadwiseWelford
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
>
;
using
ThreadwiseWelford2
=
ThreadwiseWelfordMerge
<
AccDataType
,
ThreadReduceSrcDesc_M_1
,
ThreadReduceDstDesc_M
>
;
using
BlockwiseWelford1
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
false
>
;
using
BlockwiseWelford2
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
true
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
XYGridDesc_M_K
&
x_grid_desc_m_k
,
const
XYGridDesc_M_K
&
y_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
&
mean_var_count_grid_desc_m_g
,
const
MeanVarCountGridDesc_M_K
&
mean_var_count_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
&
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
&
bias_grid_desc_m
,
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
&
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
__restrict__
p_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
,
int32_t
*
const
__restrict__
p_control
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
using
ck
::
math
::
sqrt
;
const
index_t
blkgroup_size
=
mean_var_count_grid_desc_m_g
.
GetLength
(
I1
);
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
blkgroup_size
;
const
index_t
block_local_id
=
block_global_id
%
blkgroup_size
;
if
(
block_local_id
==
0
)
gms_init
(
BlockSize
/
warpSize
*
blkgroup_size
,
&
p_control
[
blkgroup_id
*
2
]);
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
count_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
tmp_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
tmp_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
tmp_count_thread_buf
;
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
XYGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
constexpr
auto
xy_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
// Step 1: each workgroup does local welford reduction
auto
threadwise_welford_1
=
ThreadwiseWelford1
();
threadwise_welford_1
.
max_count_
=
get_reduce_count_per_thread
(
block_local_id
,
thread_k_cluster_id
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
xy_copy_fwd_step_m_k
);
threadwise_welford_1
.
Run
(
x_thread_buf
,
mean_thread_buf
,
var_thread_buf
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
count_thread_buf
(
I
)
=
threadwise_welford_1
.
cur_count_
;
BlockwiseWelford1
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count_thread_buf
(
I
));
});
// Step 2: each workgroup writes its local welford result to workspace memory
auto
mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
p_welford_mean
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
p_welford_variance
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
p_welford_count
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
threadwise_mean_var_store_m_g
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
auto
threadwise_count_store_m_g
=
ThreadwiseTensorSliceTransfer_v1r3
<
int32_t
,
int32_t
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
if
(
thread_k_cluster_id
==
0
)
{
threadwise_mean_var_store_m_g
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
mean_thread_buf
,
mean_var_count_grid_desc_m_g
,
mean_global_val_buf
);
threadwise_mean_var_store_m_g
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
var_thread_buf
,
mean_var_count_grid_desc_m_g
,
var_global_val_buf
);
threadwise_count_store_m_g
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
count_thread_buf
,
mean_var_count_grid_desc_m_g
,
count_global_val_buf
);
};
gms_barrier
(
&
p_control
[
blkgroup_id
*
2
]);
if
(
block_local_id
==
0
)
gms_reset
(
&
p_control
[
blkgroup_id
*
2
]);
// Step 3: each workgroup reads welford results from workspace memory and does final welford
// reduction
auto
threadwise_mean_var_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarCountGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
1
,
true
>
(
mean_var_count_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
auto
threadwise_count_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
int32_t
,
int32_t
,
MeanVarCountGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
1
,
true
>
(
mean_var_count_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
count_thread_buf
(
I
)
=
0
;
});
constexpr
auto
mean_var_count_read_fwd_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
);
int32_t
reducedSize
=
0
;
while
(
reducedSize
<
blkgroup_size
)
{
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
mean_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
tmp_mean_thread_buf
);
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
var_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
tmp_var_thread_buf
);
threadwise_count_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
count_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
tmp_count_thread_buf
);
ThreadwiseWelford2
::
Run
(
tmp_mean_thread_buf
,
tmp_var_thread_buf
,
tmp_count_thread_buf
,
mean_thread_buf
,
var_thread_buf
,
count_thread_buf
);
reducedSize
+=
KThreadClusterSize
;
threadwise_mean_var_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_read_fwd_step_m_k
);
threadwise_count_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_read_fwd_step_m_k
);
};
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseWelford2
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count_thread_buf
(
I
));
});
// Step 4: do normalization using the mean/variance
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
scale_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
bias_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
y_thread_buf
;
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
YDataType
,
decltype
(
thread_buffer_desc_m_k
),
XYGridDesc_M_K
,
YElementwiseOp
,
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
YDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
y_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
),
y_elementwise_op
);
auto
threadwise_scale_load
=
ThreadwiseTensorSliceTransfer_v2
<
ScaleDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
ScaleSrcVectorSize
,
1
,
true
>
(
scale_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
auto
threadwise_bias_load
=
ThreadwiseTensorSliceTransfer_v2
<
BiasDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
BiasSrcVectorSize
,
1
,
true
>
(
bias_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
const
auto
scale_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_scale
,
scale_grid_desc_m
.
GetElementSpaceSize
());
const
auto
bias_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_bias
,
bias_grid_desc_m
.
GetElementSpaceSize
());
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
threadwise_scale_load
.
Run
(
scale_grid_desc_m
,
scale_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
scale_thread_buf
);
threadwise_bias_load
.
Run
(
bias_grid_desc_m
,
bias_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
bias_thread_buf
);
threadwise_x_load
.
SetSrcSliceOrigin
(
x_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
AccDataType
multiplier
=
scale_thread_buf
[
Number
<
iM
>
{}]
/
sqrt
(
var_thread_buf
[
iM
]
+
epsilon
);
AccDataType
fused_mean_bias
=
bias_thread_buf
[
Number
<
iM
>
{}]
-
mean_thread_buf
[
iM
]
*
multiplier
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
// normalize
y_thread_buf
(
Number
<
offset
>
{})
=
x_thread_buf
[
Number
<
offset
>
{}]
*
multiplier
+
fused_mean_bias
;
});
});
threadwise_y_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
y_thread_buf
,
y_grid_desc_m_k
,
y_global_val_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
xy_copy_fwd_step_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
xy_copy_fwd_step_m_k
);
}
// Step 5: update the moving average of mean and variance (optional)
if
(
updateMovingAverage
&&
block_local_id
==
0
&&
thread_k_cluster_id
==
0
)
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_var_thread_buf
;
auto
running_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
running_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
threadwise_mean_var_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_mean_var_load
.
Run
(
mean_var_grid_desc_m
,
running_mean_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
);
threadwise_mean_var_load
.
Run
(
mean_var_grid_desc_m
,
running_var_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
);
AccDataType
oneMinusAverageFactor
=
type_convert
<
AccDataType
>
(
1.0
)
-
averageFactor
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
running_mean_thread_buf
(
I
)
=
running_mean_thread_buf
[
I
]
*
oneMinusAverageFactor
+
mean_thread_buf
[
I
]
*
averageFactor
;
running_var_thread_buf
(
I
)
=
running_var_thread_buf
[
I
]
*
oneMinusAverageFactor
+
var_thread_buf
[
I
]
*
averageFactor
;
});
auto
threadwise_mean_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
,
mean_var_grid_desc_m
,
running_mean_global_buf
);
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
,
mean_var_grid_desc_m
,
running_var_global_buf
);
};
// Step 6: save mean and inv-variance (optional)
if
(
saveMeanInvVariance
&&
block_local_id
==
0
&&
thread_k_cluster_id
==
0
)
{
auto
result_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
result_inv_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveInvVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
sqrt
(
epsilon
+
var_thread_buf
[
I
]);
});
auto
threadwise_mean_inv_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
mean_var_grid_desc_m
,
result_mean_global_buf
);
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
var_thread_buf
,
mean_var_grid_desc_m
,
result_inv_var_global_buf
);
};
}
};
// namespace ck
}
// namespace ck
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment