Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
349552e2
Commit
349552e2
authored
Jul 10, 2023
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into stream-k-initial-impl
parents
daa98dae
87f2bbcf
Changes
105
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1722 additions
and
110 deletions
+1722
-110
example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp
...quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp
+1
-1
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp
...mm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp
+1
-1
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp
...mm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp
+1
-1
include/ck/ck.hpp
include/ck/ck.hpp
+19
-0
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
+13
-11
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
...sor_operation/gpu/block/reduction_functions_blockwise.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
...eration/gpu/device/impl/device_batchnorm_forward_impl.hpp
+185
-81
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
...pu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
+714
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp
...gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
...device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+0
-0
include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
...mpl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
+0
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
...norm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
+704
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
...orm_multiblock/gridwise_multiblock_welford_first_half.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp
..._welford_second_half_batchnorm_forward_final_obsolete.hpp
+9
-10
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
...or_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
+2
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
+4
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+6
-0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+60
-2
No files found.
example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp
View file @
349552e2
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
using
InDataType
=
int8_t
;
using
WeiDataType
=
int8_t
;
...
...
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp
View file @
349552e2
...
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp
View file @
349552e2
...
...
@@ -8,7 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/
impl/
device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
...
...
include/ck/ck.hpp
View file @
349552e2
...
...
@@ -27,6 +27,21 @@
#define CK_WAVELET_MIN_BLOCK_PER_CU 2
#endif
// kernel attribute: amdgpu_waves_per_eu()
#ifdef CK_USE_WAVES_PER_EU
// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute
#ifndef CK_MIN_WAVES_PER_EU
#define CK_MIN_WAVES_PER_EU 0
#endif
#ifndef CK_MAX_WAVES_PER_EU
#define CK_MAX_WAVES_PER_EU 0
#endif
#else
#define CK_USE_WAVES_PER_EU 0
#endif
// buffer resource
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
...
...
@@ -148,6 +163,10 @@
#define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1
// experimental feature: add instances using pipeline v2
#define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1
// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy)
#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0
#endif
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
...
...
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
View file @
349552e2
...
...
@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/
reduction_common
.hpp"
#include "ck/utility/
get_shift
.hpp"
namespace
ck
{
...
...
@@ -35,10 +35,11 @@ struct BlockwiseWelford
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
template
<
typename
CountDataType
>
__device__
static
inline
void
Merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
Merge
(
T
&
mean_a
,
T
&
var_a
,
CountDataType
&
count_a
,
T
mean_b
,
T
var_b
,
CountDataType
count_b
)
{
int
count
=
count_a
+
count_b
;
CountDataType
count
=
count_a
+
count_b
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
type_convert
<
T
>
(
count_b
)
/
count
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
...
...
@@ -46,11 +47,12 @@ struct BlockwiseWelford
count_a
=
count
;
}
__device__
static
void
Run
(
T
&
mean_value
,
T
&
var_value
,
int
&
count
)
template
<
typename
CountDataType
>
__device__
static
void
Run
(
T
&
mean_value
,
T
&
var_value
,
CountDataType
&
count
)
{
__shared__
T
mean_block_buf
[
BlockSize
];
__shared__
T
var_block_buf
[
BlockSize
];
__shared__
int
count_block_buf
[
BlockSize
];
__shared__
CountDataType
count_block_buf
[
BlockSize
];
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
...
...
@@ -76,13 +78,13 @@ struct BlockwiseWelford
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
make_tuple
(
0
,
indOffset
));
T
mean1
=
mean_block_buf
[
offset1
];
T
var1
=
var_block_buf
[
offset1
];
int
count1
=
count_block_buf
[
offset1
];
T
mean1
=
mean_block_buf
[
offset1
];
T
var1
=
var_block_buf
[
offset1
];
CountDataType
count1
=
count_block_buf
[
offset1
];
T
mean2
=
mean_block_buf
[
offset2
];
T
var2
=
var_block_buf
[
offset2
];
int
count2
=
count_block_buf
[
offset2
];
T
mean2
=
mean_block_buf
[
offset2
];
T
var2
=
var_block_buf
[
offset2
];
CountDataType
count2
=
count_block_buf
[
offset2
];
Merge
(
mean1
,
var1
,
count1
,
mean2
,
var2
,
count2
);
...
...
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
View file @
349552e2
...
...
@@ -4,7 +4,7 @@
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/
reduction_common
.hpp"
#include "ck/utility/
get_shift
.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace
ck
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
View file @
349552e2
...
...
@@ -10,12 +10,14 @@
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final
_obsolete
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static
auto
MakeMeanVarCountOutputMG2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
grid_desc_m_g
=
make_
naive_tensor_descriptor_packed
(
make_tuple
(
invariantLength
,
blkGroupSize
));
const
auto
grid_desc_m_g
=
make_naive_tensor_descriptor
(
make_
tuple
(
invariantLength
,
blkGroupSize
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
...
...
@@ -132,9 +134,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
static
auto
MakeMeanVarCountInputMK2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
reduceLength
=
blkGroupSize
;
const
auto
grid_desc_m_k
=
make_
naive_tensor_descriptor_packed
(
make_tuple
(
invariantLength
,
reduceLength
));
const
auto
reduceLength
=
blkGroupSize
;
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_
tuple
(
invariantLength
,
reduceLength
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
...
...
@@ -244,8 +246,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
int
testBlkGroupSize
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
// we want the blkGroupSize be not more than 1
28
if
(
testBlkGroupSize
<=
1
28
)
// we want the blkGroupSize be not more than 1
6
if
(
testBlkGroupSize
<=
1
6
)
break
;
iterations
++
;
...
...
@@ -319,6 +321,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
void
*
workspace_mean_
;
void
*
workspace_variance_
;
void
*
workspace_count_
;
void
*
control_
;
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
...
...
@@ -340,6 +344,11 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
)
+
64
;
// workspace for barrier objects, each barrier object consists of two integers
// TODO: allocate barrier object memory globally to reuse it by other operators
workspace_size
+=
(
pArg_
->
invariant_length_
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
sizeof
(
int
)
*
2
;
}
return
(
workspace_size
);
...
...
@@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
{
// setup buffer used for intermediate welford mean
pArg_
->
workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
...
...
@@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
// setup buffer used for intermediate welfor count
pArg_
->
workspace_count_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_variance_
)
+
variance_space_sz
;
index_t
count_space_sz
=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
);
count_space_sz
=
math
::
integer_least_multiple
(
count_space_sz
,
64
);
pArg_
->
control_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_count_
)
+
count_space_sz
;
index_t
control_space_sz
=
(
pArg_
->
invariant_length_
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
sizeof
(
int
)
*
2
;
hip_check_error
(
hipMemset
(
pArg_
->
control_
,
0
,
control_space_sz
));
};
};
...
...
@@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
GridwiseMultiblockBatchNormForward_
=
GridwiseMultiblockBatchNormForward
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
using
GridwiseMultiblockWelfordFirstHalf_
=
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
AccDataType
,
...
...
@@ -441,78 +487,136 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
index_t
numMeanVarCountBlockTileIteration
=
(
arg
.
blkGroupSize_
+
KThreadClusterSize
-
1
)
/
KThreadClusterSize
;
const
auto
kern_multiblock_welford_first_half
=
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
XDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
>
;
const
auto
kern_welford_second_half_batchnorm_forward_final
=
kernel_welford_second_half_batchnorm_forward_final
<
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_batchnorm_forward_final
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_k
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
numMeanVarCountBlockTileIteration
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
// It is found that:
// 1) gfx1030 does not support the GLC enabled vector load/store, so using the
// two-kernel method for gfx1030
// 2) Profiler on gfx908 could hang even though it works when running examples
// 3) Single-kernel method works on gfx1100, but the performance it not better
// than two-kernel method (due to more warps participating the barrier)
if
(
ck
::
get_device_name
()
==
"gfx90a"
)
{
const
auto
kern_multiblock_batchnorm_fwd_
=
kernel_multiblock_batchnorm_forward
<
GridwiseMultiblockBatchNormForward_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_batchnorm_fwd_
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
// for writing to mean/variance/count
// workspace by multiple workgroups
mean_var_count_grid_desc_m_k
,
// for reading from mean/variance/count
// workspace by each workgroup
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
static_cast
<
int
*>
(
arg
.
control_
),
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
// true or false
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
// true or false
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
}
else
{
const
auto
kern_multiblock_welford_first_half
=
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
XDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
>
;
const
auto
kern_welford_second_half_batchnorm_forward_final
=
kernel_welford_second_half_batchnorm_forward_final
<
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_batchnorm_forward_final
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_k
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
};
}
else
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
0 → 100644
View file @
349552e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
bool
UseMultiblockInK
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcYDstVectorDim
,
index_t
XSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
DeviceBatchNormFwdImpl
:
public
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
"Invalid thread cluster size assignments!"
);
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeXY2dDescriptor
(
const
std
::
array
<
index_t
,
Rank
>&
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>&
xyStrides
,
int
blkGroupSize
,
int
numBlockTileIteration
)
{
const
auto
tupleXYLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
I
];
},
Number
<
Rank
>
{});
const
auto
tupleXYStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
xyStrides
[
I
];
},
Number
<
Rank
>
{});
const
auto
raw_grid_desc
=
make_naive_tensor_descriptor
(
tupleXYLengths
,
tupleXYStrides
);
const
auto
grid_desc_m_k
=
[
&
]()
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
NumInvariantDim
+
I
];
},
Number
<
NumBatchNormReduceDim
>
{});
const
auto
invariantDimLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
return
transform_tensor_descriptor
(
raw_grid_desc
,
make_tuple
(
make_merge_transform
(
invariantDimLengths
),
make_merge_transform
(
reduceDimLengths
)),
make_tuple
(
InvariantDims
{},
ReduceDims
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
const
auto
invariantLength
=
grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
int
workSizePerBlock
=
K_BlockTileSize
*
numBlockTileIteration
;
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
kPad
=
workSizePerBlock
*
blkGroupSize
-
reduceLength
;
auto
grid_desc_m_k_padded
=
transform_tensor_descriptor
(
grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_right_pad_transform
(
reduceLength
,
kPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_k_padded
);
};
static
auto
MakeMeanVarCountOutputMG2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
grid_desc_m_g
=
make_naive_tensor_descriptor
(
make_tuple
(
invariantLength
,
blkGroupSize
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_g_padded
=
transform_tensor_descriptor
(
grid_desc_m_g
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_pass_through_transform
(
blkGroupSize
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_g_padded
);
};
static
auto
MakeMeanVarCountInputMK2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
reduceLength
=
blkGroupSize
;
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor
(
make_tuple
(
invariantLength
,
reduceLength
),
make_tuple
(
1
,
invariantLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
kPad
=
math
::
integer_least_multiple
(
reduceLength
,
KThreadClusterSize
)
-
reduceLength
;
auto
grid_desc_m_k_padded
=
transform_tensor_descriptor
(
grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_right_pad_transform
(
reduceLength
,
kPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_k_padded
);
};
static
auto
MakeScaleBiasMeanVar1dDescriptor
(
const
std
::
array
<
index_t
,
NumInvariantDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>&
strides
)
{
const
auto
tupleLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
const
auto
tupleStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
strides
[
I
];
},
Number
<
NumInvariantDim
>
{});
auto
raw_grid_desc
=
make_naive_tensor_descriptor
(
tupleLengths
,
tupleStrides
);
auto
grid_desc_m
=
transform_tensor_descriptor
(
raw_grid_desc
,
make_tuple
(
make_merge_transform
(
tupleLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_padded
=
transform_tensor_descriptor
(
grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
(
grid_desc_m_padded
);
};
using
XYGridDesc_M_K
=
decltype
(
MakeXY2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
ScaleBiasMeanVarGridDesc_M
=
decltype
(
MakeScaleBiasMeanVar1dDescriptor
({
1
},
{
1
}));
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
ScaleDataType
*
p_scale
,
const
BiasDataType
*
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
double
epsilon
,
YDataType
*
p_y
,
MeanVarDataType
*
resultSaveMean
,
MeanVarDataType
*
resultSaveInvVariance
,
double
averageFactor
,
MeanVarDataType
*
resultRunningMean
,
MeanVarDataType
*
resultRunningVariance
)
:
bnScaleBiasMeanVarLengths_
(
bnScaleBiasMeanVarLengths
),
bnScaleStrides_
(
bnScaleStrides
),
bnBiasStrides_
(
bnBiasStrides
),
bnMeanVarStrides_
(
bnMeanVarStrides
),
p_x_
(
p_x
),
p_scale_
(
p_scale
),
p_bias_
(
p_bias
),
y_elementwise_op_
(
y_elementwise_op
),
p_y_
(
p_y
),
resultSaveMean_
(
resultSaveMean
),
resultSaveInvVariance_
(
resultSaveInvVariance
),
resultRunningMean_
(
resultRunningMean
),
resultRunningVariance_
(
resultRunningVariance
)
{
xyLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xStrides
,
reduceDims
);
yStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
yStrides
,
reduceDims
);
std
::
tie
(
invariant_length_
,
reduce_length_
)
=
get_2d_lengths
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths_
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
averageFactor_
=
type_convert
<
AccDataType
>
(
averageFactor
);
updateMovingAverage_
=
(
resultRunningMean
!=
nullptr
&&
resultRunningVariance
!=
nullptr
);
saveMeanInvVariance_
=
(
resultSaveMean
!=
nullptr
&&
resultSaveInvVariance_
!=
nullptr
);
if
(
UseMultiblockInK
)
{
int
iterations
=
1
;
while
(
true
)
{
int
testBlkGroupSize
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
// we want the blkGroupSize be not more than 16
if
(
testBlkGroupSize
<=
16
)
break
;
iterations
++
;
};
blkGroupSize_
=
(
reduce_length_
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
numBlockTileIteration_
=
iterations
;
}
else
{
blkGroupSize_
=
1
;
numBlockTileIteration_
=
(
reduce_length_
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
};
gridSize_
=
(
invariant_length_
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
blkGroupSize_
;
x_grid_desc_m_k_
=
MakeXY2dDescriptor
(
xyLengths_
,
xStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
y_grid_desc_m_k_
=
MakeXY2dDescriptor
(
xyLengths_
,
yStrides_
,
blkGroupSize_
,
numBlockTileIteration_
);
scale_grid_desc_m_
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnScaleStrides_
);
bias_grid_desc_m_
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnBiasStrides_
);
mean_var_grid_desc_m_
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnMeanVarStrides_
);
}
AccDataType
epsilon_
;
AccDataType
averageFactor_
;
bool
updateMovingAverage_
;
bool
saveMeanInvVariance_
;
std
::
array
<
index_t
,
Rank
>
xyLengths_
;
std
::
array
<
index_t
,
Rank
>
xStrides_
;
std
::
array
<
index_t
,
Rank
>
yStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides_
;
const
XDataType
*
p_x_
;
const
ScaleDataType
*
p_scale_
;
const
BiasDataType
*
p_bias_
;
const
YElementwiseOp
y_elementwise_op_
;
YDataType
*
p_y_
;
MeanVarDataType
*
resultSaveMean_
;
MeanVarDataType
*
resultSaveInvVariance_
;
MeanVarDataType
*
resultRunningMean_
;
MeanVarDataType
*
resultRunningVariance_
;
long_index_t
invariant_length_
;
long_index_t
reduce_length_
;
int
blkGroupSize_
;
int
numBlockTileIteration_
;
size_t
gridSize_
;
XYGridDesc_M_K
x_grid_desc_m_k_
;
XYGridDesc_M_K
y_grid_desc_m_k_
;
ScaleBiasMeanVarGridDesc_M
scale_grid_desc_m_
;
ScaleBiasMeanVarGridDesc_M
bias_grid_desc_m_
;
ScaleBiasMeanVarGridDesc_M
mean_var_grid_desc_m_
;
void
*
workspace_mean_
;
void
*
workspace_variance_
;
void
*
workspace_count_
;
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
size_t
workspace_size
=
0
;
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
{
// workspace for welford intermediate mean
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate variance
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
int32_t
)
+
64
;
}
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
pArg_
->
p_workspace_
=
p_workspace
;
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize_
>
1
)
{
// setup buffer used for intermediate welford mean
pArg_
->
workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
index_t
mean_space_sz
=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
// setup buffer used for intermediate welford varirance
pArg_
->
workspace_variance_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_mean_
)
+
mean_space_sz
;
index_t
variance_space_sz
=
pArg_
->
invariant_length_
*
pArg_
->
blkGroupSize_
*
sizeof
(
MeanVarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
// setup buffer used for intermediate welfor count
pArg_
->
workspace_count_
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_variance_
)
+
variance_space_sz
;
};
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
avg_time
=
0
;
if
(
UseMultiblockInK
&&
arg
.
blkGroupSize_
>
1
)
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForMultiblockWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
arg
.
reduce_length_
);
const
auto
mean_var_count_grid_desc_m_g
=
DeviceBatchNormFwdImpl
::
MakeMeanVarCountOutputMG2dDescriptor
(
arg
.
invariant_length_
,
arg
.
blkGroupSize_
);
const
auto
mean_var_count_grid_desc_m_k
=
DeviceBatchNormFwdImpl
::
MakeMeanVarCountInputMK2dDescriptor
(
arg
.
invariant_length_
,
arg
.
blkGroupSize_
);
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
GridwiseMultiblockWelfordFirstHalf_
=
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
AccDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
>
;
using
GridwiseWelfordSecondHalfBatchNormForwardFinal_
=
GridwiseWelfordSecondHalfBatchNormForwardFinal
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
const
auto
kern_multiblock_welford_first_half
=
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
XDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
>
;
const
auto
kern_welford_second_half_batchnorm_forward_final
=
kernel_welford_second_half_batchnorm_forward_final
<
GridwiseWelfordSecondHalfBatchNormForwardFinal_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_batchnorm_forward_final
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
mean_var_count_grid_desc_m_k
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
arg
.
blkGroupSize_
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean_
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance_
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count_
),
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
}
else
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForBlockwiseWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
numBlockTileIteration_
,
arg
.
reduce_length_
);
using
GridwiseBatchNormForwardWithBlockwiseWelford_
=
GridwiseBatchNormForwardWithBlockwiseWelford
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
YDstVectorSize
,
ScaleSrcVectorSize
,
BiasSrcVectorSize
,
MeanVarSrcDstVectorSize
>
;
const
auto
kern_batchnorm_fwd
=
kernel_batchnorm_forward_with_blockwise_welford
<
GridwiseBatchNormForwardWithBlockwiseWelford_
,
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasMeanVarGridDesc_M
,
ScaleBiasMeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_batchnorm_fwd
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k_
,
arg
.
y_grid_desc_m_k_
,
arg
.
scale_grid_desc_m_
,
arg
.
bias_grid_desc_m_
,
arg
.
mean_var_grid_desc_m_
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration_
,
arg
.
epsilon_
,
arg
.
p_x_
,
arg
.
p_scale_
,
arg
.
p_bias_
,
arg
.
y_elementwise_op_
,
arg
.
p_y_
,
arg
.
updateMovingAverage_
,
// true or false
arg
.
averageFactor_
,
arg
.
resultRunningMean_
,
arg
.
resultRunningVariance_
,
arg
.
saveMeanInvVariance_
,
// true or false
arg
.
resultSaveMean_
,
arg
.
resultSaveInvVariance_
);
};
return
(
avg_time
);
};
float
Run
(
const
BaseArgument
*
pArg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
pArg
),
stream_config
);
};
};
bool
IsSupportedArgument
(
const
BaseArgument
*
pArg
)
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
if
constexpr
(
XSrcYDstVectorDim
==
0
)
{
if
(
pArg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
||
pArg_
->
yStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
YDstVectorSize
!=
0
)
return
false
;
}
else
{
if
(
pArg_
->
xStrides_
[
Rank
-
1
]
!=
1
||
pArg_
->
yStrides_
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
Rank
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
Rank
-
1
]
%
YDstVectorSize
!=
0
)
return
false
;
};
if
(
pArg_
->
bnScaleStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
ScaleSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnBiasStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
BiasSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
ScaleSrcVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
BiasSrcVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
bnMeanVarStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
MeanVarSrcDstVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
MeanVarSrcDstVectorSize
!=
0
)
return
false
;
bool
is_valid
=
true
;
static_for
<
0
,
NumInvariantDim
,
1
>
{}([
&
](
auto
I
)
{
if
(
pArg_
->
xyLengths_
[
I
]
!=
pArg_
->
bnScaleBiasMeanVarLengths_
[
I
])
is_valid
=
false
;
});
if
(
!
is_valid
)
return
false
;
return
true
;
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
p_scale
,
const
void
*
p_bias
,
double
epsilon
,
const
YElementwiseOp
y_elementwise_op
,
void
*
p_y
,
void
*
resultSaveMean
,
void
*
resultSaveInvVariance
,
double
averageFactor
,
void
*
resultRunningMean
,
void
*
resultRunningVariance
)
override
{
return
std
::
make_unique
<
Argument
>
(
xyLengths
,
xStrides
,
yStrides
,
reduceDims
,
bnScaleBiasMeanVarLengths
,
bnScaleStrides
,
bnBiasStrides
,
bnMeanVarStrides
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
ScaleDataType
*>
(
p_scale
),
static_cast
<
const
BiasDataType
*>
(
p_bias
),
y_elementwise_op
,
epsilon
,
static_cast
<
YDataType
*>
(
p_y
),
static_cast
<
MeanVarDataType
*>
(
resultSaveMean
),
static_cast
<
MeanVarDataType
*>
(
resultSaveInvVariance
),
averageFactor
,
static_cast
<
MeanVarDataType
*>
(
resultRunningMean
),
static_cast
<
MeanVarDataType
*>
(
resultRunningVariance
));
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatchNormFwdImpl<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"XSrcYDstVectorDim_"
<<
XSrcYDstVectorDim
<<
","
;
str
<<
"VectorSize_X"
<<
XSrcVectorSize
<<
"_scale_"
<<
ScaleSrcVectorSize
<<
"_bias_"
<<
BiasSrcVectorSize
<<
"_mean_var_"
<<
MeanVarSrcDstVectorSize
<<
"_Y"
<<
YDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_gemm_xdl_waveletmodel_cshuffle.hpp
View file @
349552e2
File moved
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
349552e2
File moved
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp
View file @
349552e2
File moved
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
349552e2
File moved
include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/
impl/
device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
View file @
349552e2
File moved
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
0 → 100644
View file @
349552e2
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_synchronization.hpp"
namespace
ck
{
template
<
typename
GridwiseMultiblockBatchNormForward_
,
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
GetReduceCountPerThreadFunctor
>
__global__
void
kernel_multiblock_batchnorm_forward
(
const
XYGridDesc_M_K
x_grid_desc_m_k
,
const
XYGridDesc_M_K
y_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
mean_var_count_grid_desc_m_g
,
const
MeanVarCountGridDesc_M_K
mean_var_count_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
bias_grid_desc_m
,
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
__restrict__
p_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
,
int32_t
*
const
__restrict__
p_control
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
GridwiseMultiblockBatchNormForward_
::
Run
(
x_grid_desc_m_k
,
y_grid_desc_m_k
,
mean_var_count_grid_desc_m_g
,
mean_var_count_grid_desc_m_k
,
scale_grid_desc_m
,
bias_grid_desc_m
,
mean_var_grid_desc_m
,
get_reduce_count_per_thread
,
num_k_block_tile_iteration
,
epsilon
,
p_x
,
p_welford_mean
,
p_welford_variance
,
p_welford_count
,
p_control
,
p_scale
,
p_bias
,
y_elementwise_op
,
p_y
,
updateMovingAverage
,
averageFactor
,
resultRunningMean
,
resultRunningVariance
,
saveMeanInvVariance
,
resultSaveMean
,
resultSaveInvVariance
);
};
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
typename
XYGridDesc_M_K
,
typename
MeanVarCountGridDesc_M_G
,
typename
MeanVarCountGridDesc_M_K
,
typename
ScaleBiasGridDesc_M
,
typename
MeanVarGridDesc_M
,
typename
GetReduceCountPerThreadFunctor
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XSrcYDstVectorDim
,
index_t
XSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
GridwiseMultiblockBatchNormForward
{
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
((
XSrcYDstVectorDim
==
0
&&
MThreadSliceSize
%
YDstVectorSize
==
0
)
||
(
XSrcYDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcYDstVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
ThreadReduceSrcDesc_M_1
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{})));
using
ThreadwiseWelford1
=
ThreadwiseWelford
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
>
;
using
ThreadwiseWelford2
=
ThreadwiseWelfordMerge
<
AccDataType
,
ThreadReduceSrcDesc_M_1
,
ThreadReduceDstDesc_M
>
;
using
BlockwiseWelford1
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
false
>
;
using
BlockwiseWelford2
=
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
true
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
XYGridDesc_M_K
&
x_grid_desc_m_k
,
const
XYGridDesc_M_K
&
y_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
&
mean_var_count_grid_desc_m_g
,
const
MeanVarCountGridDesc_M_K
&
mean_var_count_grid_desc_m_k
,
const
ScaleBiasGridDesc_M
&
scale_grid_desc_m
,
const
ScaleBiasGridDesc_M
&
bias_grid_desc_m
,
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
const
GetReduceCountPerThreadFunctor
&
get_reduce_count_per_thread
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
XDataType
*
const
__restrict__
p_x
,
MeanVarDataType
*
const
__restrict__
p_welford_mean
,
MeanVarDataType
*
const
__restrict__
p_welford_variance
,
int32_t
*
const
__restrict__
p_welford_count
,
int32_t
*
const
__restrict__
p_control
,
const
ScaleDataType
*
const
__restrict__
p_scale
,
const
BiasDataType
*
const
__restrict__
p_bias
,
const
YElementwiseOp
y_elementwise_op
,
YDataType
*
const
__restrict__
p_y
,
bool
updateMovingAverage
,
AccDataType
averageFactor
,
MeanVarDataType
*
const
__restrict__
resultRunningMean
,
MeanVarDataType
*
const
__restrict__
resultRunningVariance
,
bool
saveMeanInvVariance
,
MeanVarDataType
*
const
__restrict__
resultSaveMean
,
MeanVarDataType
*
const
__restrict__
resultSaveInvVariance
)
{
using
ck
::
math
::
sqrt
;
const
index_t
blkgroup_size
=
mean_var_count_grid_desc_m_g
.
GetLength
(
I1
);
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
blkgroup_size
;
const
index_t
block_local_id
=
block_global_id
%
blkgroup_size
;
if
(
block_local_id
==
0
)
gms_init
(
BlockSize
/
warpSize
*
blkgroup_size
,
&
p_control
[
blkgroup_id
*
2
]);
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
using
ThreadBufferLengths_M_K
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
using
ThreadBufferLengths_M
=
Sequence
<
MThreadSliceSize
>
;
using
ThreadBufferLengths_M_1
=
Sequence
<
MThreadSliceSize
,
1
>
;
constexpr
auto
thread_buffer_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{}));
constexpr
auto
thread_buffer_desc_m_1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
1
>
{}));
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
x_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
count_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
tmp_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
tmp_var_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
MThreadSliceSize
,
true
>
tmp_count_thread_buf
;
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
auto
threadwise_x_load
=
ThreadwiseTensorSliceTransfer_v2
<
XDataType
,
AccDataType
,
XYGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_k
),
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
XSrcVectorSize
,
1
,
true
>
(
x_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
constexpr
auto
xy_copy_fwd_step_m_k
=
make_multi_index
(
0
,
K_BlockTileSize
);
const
auto
x_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x
,
x_grid_desc_m_k
.
GetElementSpaceSize
());
// Step 1: each workgroup does local welford reduction
auto
threadwise_welford_1
=
ThreadwiseWelford1
();
threadwise_welford_1
.
max_count_
=
get_reduce_count_per_thread
(
block_local_id
,
thread_k_cluster_id
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
xy_copy_fwd_step_m_k
);
threadwise_welford_1
.
Run
(
x_thread_buf
,
mean_thread_buf
,
var_thread_buf
);
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
count_thread_buf
(
I
)
=
threadwise_welford_1
.
cur_count_
;
BlockwiseWelford1
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count_thread_buf
(
I
));
});
// Step 2: each workgroup writes its local welford result to workspace memory
auto
mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
p_welford_mean
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
p_welford_variance
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
p_welford_count
,
mean_var_count_grid_desc_m_g
.
GetElementSpaceSize
());
auto
threadwise_mean_var_store_m_g
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
auto
threadwise_count_store_m_g
=
ThreadwiseTensorSliceTransfer_v1r3
<
int32_t
,
int32_t
,
decltype
(
thread_buffer_desc_m_1
),
MeanVarCountGridDesc_M_G
,
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_count_grid_desc_m_g
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
),
PassThroughOp
{});
if
(
thread_k_cluster_id
==
0
)
{
threadwise_mean_var_store_m_g
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
mean_thread_buf
,
mean_var_count_grid_desc_m_g
,
mean_global_val_buf
);
threadwise_mean_var_store_m_g
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
var_thread_buf
,
mean_var_count_grid_desc_m_g
,
var_global_val_buf
);
threadwise_count_store_m_g
.
Run
(
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
count_thread_buf
,
mean_var_count_grid_desc_m_g
,
count_global_val_buf
);
};
gms_barrier
(
&
p_control
[
blkgroup_id
*
2
]);
if
(
block_local_id
==
0
)
gms_reset
(
&
p_control
[
blkgroup_id
*
2
]);
// Step 3: each workgroup reads welford results from workspace memory and does final welford
// reduction
auto
threadwise_mean_var_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarCountGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
1
,
true
>
(
mean_var_count_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
auto
threadwise_count_load_m_k
=
ThreadwiseTensorSliceTransfer_v2
<
int32_t
,
int32_t
,
MeanVarCountGridDesc_M_K
,
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
0
,
1
,
1
,
true
>
(
mean_var_count_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_k_cluster_id
*
1
));
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
mean_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
count_thread_buf
(
I
)
=
0
;
});
constexpr
auto
mean_var_count_read_fwd_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
);
int32_t
reducedSize
=
0
;
while
(
reducedSize
<
blkgroup_size
)
{
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
mean_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
tmp_mean_thread_buf
);
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
var_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
tmp_var_thread_buf
);
threadwise_count_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
count_global_val_buf
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
tmp_count_thread_buf
);
ThreadwiseWelford2
::
Run
(
tmp_mean_thread_buf
,
tmp_var_thread_buf
,
tmp_count_thread_buf
,
mean_thread_buf
,
var_thread_buf
,
count_thread_buf
);
reducedSize
+=
KThreadClusterSize
;
threadwise_mean_var_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_read_fwd_step_m_k
);
threadwise_count_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_read_fwd_step_m_k
);
};
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseWelford2
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count_thread_buf
(
I
));
});
// Step 4: do normalization using the mean/variance
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
scale_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
bias_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
y_thread_buf
;
auto
threadwise_y_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
YDataType
,
decltype
(
thread_buffer_desc_m_k
),
XYGridDesc_M_K
,
YElementwiseOp
,
ThreadBufferLengths_M_K
,
ThreadBufferDimAccessOrder
,
XSrcYDstVectorDim
,
YDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
y_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
),
y_elementwise_op
);
auto
threadwise_scale_load
=
ThreadwiseTensorSliceTransfer_v2
<
ScaleDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
ScaleSrcVectorSize
,
1
,
true
>
(
scale_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
auto
threadwise_bias_load
=
ThreadwiseTensorSliceTransfer_v2
<
BiasDataType
,
AccDataType
,
ScaleBiasGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
BiasSrcVectorSize
,
1
,
true
>
(
bias_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
const
auto
scale_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_scale
,
scale_grid_desc_m
.
GetElementSpaceSize
());
const
auto
bias_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_bias
,
bias_grid_desc_m
.
GetElementSpaceSize
());
auto
y_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
threadwise_scale_load
.
Run
(
scale_grid_desc_m
,
scale_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
scale_thread_buf
);
threadwise_bias_load
.
Run
(
bias_grid_desc_m
,
bias_global_val_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
bias_thread_buf
);
threadwise_x_load
.
SetSrcSliceOrigin
(
x_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
AccDataType
multiplier
=
scale_thread_buf
[
Number
<
iM
>
{}]
/
sqrt
(
var_thread_buf
[
iM
]
+
epsilon
);
AccDataType
fused_mean_bias
=
bias_thread_buf
[
Number
<
iM
>
{}]
-
mean_thread_buf
[
iM
]
*
multiplier
;
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
// normalize
y_thread_buf
(
Number
<
offset
>
{})
=
x_thread_buf
[
Number
<
offset
>
{}]
*
multiplier
+
fused_mean_bias
;
});
});
threadwise_y_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
y_thread_buf
,
y_grid_desc_m_k
,
y_global_val_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
xy_copy_fwd_step_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
xy_copy_fwd_step_m_k
);
}
// Step 5: update the moving average of mean and variance (optional)
if
(
updateMovingAverage
&&
block_local_id
==
0
&&
thread_k_cluster_id
==
0
)
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
running_var_thread_buf
;
auto
running_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
running_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultRunningVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
threadwise_mean_var_load
=
ThreadwiseTensorSliceTransfer_v2
<
MeanVarDataType
,
AccDataType
,
MeanVarGridDesc_M
,
decltype
(
thread_buffer_desc_m
),
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
));
threadwise_mean_var_load
.
Run
(
mean_var_grid_desc_m
,
running_mean_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
);
threadwise_mean_var_load
.
Run
(
mean_var_grid_desc_m
,
running_var_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
);
AccDataType
oneMinusAverageFactor
=
type_convert
<
AccDataType
>
(
1.0
)
-
averageFactor
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
running_mean_thread_buf
(
I
)
=
running_mean_thread_buf
[
I
]
*
oneMinusAverageFactor
+
mean_thread_buf
[
I
]
*
averageFactor
;
running_var_thread_buf
(
I
)
=
running_var_thread_buf
[
I
]
*
oneMinusAverageFactor
+
var_thread_buf
[
I
]
*
averageFactor
;
});
auto
threadwise_mean_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_mean_thread_buf
,
mean_var_grid_desc_m
,
running_mean_global_buf
);
threadwise_mean_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
running_var_thread_buf
,
mean_var_grid_desc_m
,
running_var_global_buf
);
};
// Step 6: save mean and inv-variance (optional)
if
(
saveMeanInvVariance
&&
block_local_id
==
0
&&
thread_k_cluster_id
==
0
)
{
auto
result_mean_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveMean
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
auto
result_inv_var_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
resultSaveInvVariance
,
mean_var_grid_desc_m
.
GetElementSpaceSize
());
// calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
sqrt
(
epsilon
+
var_thread_buf
[
I
]);
});
auto
threadwise_mean_inv_var_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanVarDataType
,
decltype
(
thread_buffer_desc_m
),
MeanVarGridDesc_M
,
PassThroughOp
,
ThreadBufferLengths_M
,
Sequence
<
0
>
,
0
,
MeanVarSrcDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
mean_var_grid_desc_m
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
),
PassThroughOp
{});
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
mean_thread_buf
,
mean_var_grid_desc_m
,
result_mean_global_buf
);
threadwise_mean_inv_var_store
.
Run
(
thread_buffer_desc_m
,
make_tuple
(
I0
),
var_thread_buf
,
mean_var_grid_desc_m
,
result_inv_var_global_buf
);
};
}
};
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
View file @
349552e2
...
...
@@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
...
...
@@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf
PassThroughOp
,
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp
→
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final
_obsolete
.hpp
View file @
349552e2
...
...
@@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
const
MeanVarGridDesc_M
mean_var_grid_desc_m
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
...
...
@@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final(
mean_var_grid_desc_m
,
blkgroup_size
,
num_xy_k_block_tile_iteration
,
num_mean_var_count_k_block_tile_iteration
,
epsilon
,
p_in_welford_mean
,
p_in_welford_variance
,
...
...
@@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const
MeanVarGridDesc_M
&
mean_var_grid_desc_m
,
index_t
blkgroup_size
,
index_t
num_xy_k_block_tile_iteration
,
index_t
num_mean_var_count_k_block_tile_iteration
,
AccDataType
epsilon
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_mean
,
const
MeanVarDataType
*
const
__restrict__
p_in_welford_variance
,
...
...
@@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
1
,
true
>
(
...
...
@@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
decltype
(
thread_buffer_desc_m_1
),
ThreadBufferLengths_M_1
,
Sequence
<
0
,
1
>
,
1
,
0
,
1
,
1
,
true
>
(
...
...
@@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
const
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_count
,
mean_var_count_grid_desc_m_k
.
GetElementSpaceSize
());
constexpr
auto
mean_var_count_thread_copy_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
*
1
);
// Step 1: do final welford reduction to get mean and variance
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
...
...
@@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_count_thread_buf
(
I
)
=
0
;
});
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_mean_var_count_k_block_tile_iteration
;
++
reducedTiles
)
constexpr
auto
mean_var_count_thread_copy_step_m_k
=
make_multi_index
(
0
,
KThreadClusterSize
);
int32_t
reducedSize
=
0
;
while
(
reducedSize
<
blkgroup_size
)
{
threadwise_mean_var_load_m_k
.
Run
(
mean_var_count_grid_desc_m_k
,
welford_mean_global_val_buf
,
...
...
@@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
welford_var_thread_buf
,
welford_count_thread_buf
);
reducedSize
+=
KThreadClusterSize
;
threadwise_mean_var_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
mean_var_count_thread_copy_step_m_k
);
threadwise_count_load_m_k
.
MoveSrcSliceWindow
(
mean_var_count_grid_desc_m_k
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
View file @
349552e2
...
...
@@ -3,6 +3,8 @@
#pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
View file @
349552e2
...
...
@@ -79,6 +79,10 @@ struct GridwiseGemmPipeline_v2
do
{
#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
__builtin_amdgcn_iglp_opt
(
CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT
);
#endif
block_sync_lds
();
// GEMM i
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
349552e2
...
...
@@ -27,6 +27,9 @@ template <typename GridwiseGemm,
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#if CK_USE_WAVES_PER_EU
__attribute__
((
amdgpu_waves_per_eu
(
CK_MIN_WAVES_PER_EU
,
CK_MAX_WAVES_PER_EU
)))
#endif
kernel_gemm_xdlops_v2r3
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
...
...
@@ -60,6 +63,9 @@ template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#if CK_USE_WAVES_PER_EU
__attribute__
((
amdgpu_waves_per_eu
(
CK_MIN_WAVES_PER_EU
,
CK_MAX_WAVES_PER_EU
)))
#endif
kernel_gemm_xdlops_v2r3
(
const
typename
GridwiseGemm
::
Argument
karg
)
{
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
349552e2
...
...
@@ -29,7 +29,9 @@ enum struct MfmaInstr
mfma_i32_16x16x16i8
,
mfma_i32_32x32x16i8
,
mfma_i32_16x16x32i8
,
mfma_f64_16x16x4f64
mfma_f64_16x16x4f64
,
mfma_f32_32x32x16f8f8
,
mfma_f32_16x16x32f8f8
};
template
<
MfmaInstr
instr
>
...
...
@@ -454,6 +456,50 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x16f8f8
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
intrin_mfma_f32_32x32x16f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_16x16x32f8f8
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
intrin_mfma_f32_16x16x32f8f8
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
reg_c
);
}
};
template
<
typename
base_type
,
index_t
MPerXdlops
,
index_t
NPerXdlops
>
struct
MfmaSelector
{
...
...
@@ -594,6 +640,18 @@ struct MfmaSelector
}
#endif
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
__host__
__device__
constexpr
MfmaSelector
()
...
...
@@ -794,7 +852,7 @@ struct XdlopsGemm
{
static_assert
(
is_same
<
base_type
,
double
>::
value
||
is_same
<
base_type
,
float
>::
value
||
is_same
<
base_type
,
half_t
>::
value
||
is_same
<
base_type
,
bhalf_t
>::
value
||
is_same
<
base_type
,
int8_t
>::
value
,
is_same
<
base_type
,
int8_t
>::
value
||
is_same
<
base_type
,
f8_t
>::
value
,
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment