gaoqiong / composable_kernel · Commits · e7be2fe8

Unverified commit e7be2fe8, authored Feb 10, 2023 by pmaybank; committed by GitHub on Feb 10, 2023.

Merge branch 'develop' into sphinx_doc

Parents: f68fa79a, f7d28f3e
Changes: 343. Showing 20 changed files with 1638 additions and 168 deletions (+1638 −168).
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp  +18 −15
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp  +5 −3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp  +5 −3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp  +5 −3
include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp  +2 −2
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp  +3 −3
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp  +1 −0
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp  +59 −84
include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp  +59 −0
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp  +507 −0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp  +12 −1
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp  +288 −0
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp  +30 −44
include/ck/utility/amd_inline_asm.hpp  +6 −0
include/ck/utility/amd_wmma.hpp  +199 −0
include/ck/utility/amd_xdlops.hpp  +1 −1
include/ck/utility/math_v2.hpp  +20 −4
include/ck/utility/reduction_operator.hpp  +5 −5
include/ck/utility/synchronization.hpp  +1 −0
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp  +412 −0
Too many changes to show. To preserve performance only 343 of 343+ files are displayed.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp (view file @ e7be2fe8)

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
@@ -109,7 +109,9 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          LoopScheduler LoopSched       = make_default_loop_scheduler(),
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 {
     static constexpr auto I0 = Number<0>{};
@@ -126,7 +128,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe = remove_cvref_t<decltype(
+        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
@@ -423,18 +426,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register
         // sanity check
-        auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                FloatAB,
-                                                                FloatAcc,
-                                                                decltype(a_block_desc_k0_m_k1),
-                                                                decltype(b_block_desc_k0_n_k1),
-                                                                MPerXDL,
-                                                                NPerXDL,
-                                                                MXdlPerWave,
-                                                                NXdlPerWave,
-                                                                K1>{};
+        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
+            BlockSize,
+            FloatAB,
+            FloatAcc,
+            decltype(a_block_desc_k0_m_k1),
+            decltype(b_block_desc_k0_n_k1),
+            MPerXDL,
+            NPerXDL,
+            MXdlPerWave,
+            NXdlPerWave,
+            K1,
+            LoopSched>();

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
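The change above, repeated in the v3r1/v3r2/v3r3 headers that follow, swaps the hard-coded GridwiseGemmPipeline_v1 alias for a type chosen at compile time by a selector function and recovered with remove_cvref_t<decltype(...)>. A minimal standalone sketch of that idiom is below; PipelineV1, PipelineV2 and pipeline_selector are hypothetical stand-ins, not CK's actual pipeline classes.

#include <cstdio>
#include <type_traits>

enum class PipelineVersion { v1, v2 };

// Stand-ins for the real pipeline implementations (hypothetical).
template <int NumPrefetchStages>
struct PipelineV1 { static constexpr int stages = NumPrefetchStages; };
template <int NumPrefetchStages>
struct PipelineV2 { static constexpr int stages = NumPrefetchStages + 1; };

// A constexpr "selector" returns a value of the chosen type; the caller turns
// that value back into a type with decltype, as the diff does with
// GridwiseGemmPipeline_Selector.
template <PipelineVersion Ver, int NumPrefetchStages>
constexpr auto pipeline_selector()
{
    if constexpr(Ver == PipelineVersion::v1)
        return PipelineV1<NumPrefetchStages>{};
    else
        return PipelineV2<NumPrefetchStages>{};
}

template <PipelineVersion Ver, int NumPrefetchStages>
using SelectedPipeline = std::remove_cv_t<
    std::remove_reference_t<decltype(pipeline_selector<Ver, NumPrefetchStages>())>>;

int main()
{
    static_assert(std::is_same_v<SelectedPipeline<PipelineVersion::v1, 1>, PipelineV1<1>>);
    std::printf("v2 stages: %d\n", SelectedPipeline<PipelineVersion::v2, 1>::stages);
}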
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp (view file @ e7be2fe8)

@@ -9,7 +9,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -117,7 +117,8 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 {
     static constexpr auto I0 = Number<0>{};
@@ -137,7 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe =
+        remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp (view file @ e7be2fe8)

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp"
@@ -123,7 +123,8 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
 {
     static constexpr auto I0 = Number<0>{};
@@ -140,7 +141,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe =
+        remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp (view file @ e7be2fe8)

@@ -8,7 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp"
@@ -132,7 +132,8 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumGemmKPrefetchStage = 1>
+          index_t NumGemmKPrefetchStage = 1,
+          PipelineVersion PipelineVer   = PipelineVersion::v1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 {
     static constexpr auto I0 = Number<0>{};
@@ -149,7 +150,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+    using GridwiseGemmPipe =
+        remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp → include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp (view file @ e7be2fe8)

@@ -14,7 +14,7 @@
 namespace ck {

-// Y = LayerNorm(X, Beta, Gamma)
+// Y = Normalization(X, Beta, Gamma)
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
@@ -36,7 +36,7 @@ template <typename XDataType,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
           bool SweepOnce>
-struct GridwiseLayernormNaiveVariance_mk_to_mk
+struct GridwiseNormalizationNaiveVariance_mk_to_mk
 {
     static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                       (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp → include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp (view file @ e7be2fe8)

@@ -11,7 +11,7 @@
 namespace ck {

-// Y = LayerNorm(X, Beta, Gamma)
+// Y = Normalization(X, Beta, Gamma)
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
@@ -33,7 +33,7 @@ template <typename XDataType,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
           bool SweepOnce>
-struct GridwiseLayernormWelfordVariance_mk_to_mk
+struct GridwiseNormalizationWelfordVariance_mk_to_mk
 {
     static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                       (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
@@ -319,7 +319,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
             });

             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                auto divisor = 1 / __builtin_amdgcn_sqrtf(var_thread_buf(iM) + epsilon);
+                auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
                 static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                     static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                         constexpr auto offset_m_k =
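The hunk above replaces the float-only __builtin_amdgcn_sqrtf intrinsic with ck::math::sqrt when forming the 1/sqrt(var + epsilon) divisor. Below is a rough host-side sketch of why a thin overloaded wrapper is convenient for that expression; my_sqrt and inv_std are illustrative names, not CK's ck::math::sqrt.

#include <cmath>
#include <cstdio>

// Hypothetical type-dispatching wrapper: one spelling at the call site,
// the matching precision underneath.
inline float  my_sqrt(float x)  { return std::sqrt(x); }
inline double my_sqrt(double x) { return std::sqrt(x); }

template <typename T>
T inv_std(T variance, T epsilon)
{
    // Same shape as the kernel's "divisor = 1 / sqrt(var + epsilon)".
    return T{1} / my_sqrt(variance + epsilon);
}

int main()
{
    std::printf("float : %f\n", inv_std(4.0f, 1e-5f));
    std::printf("double: %f\n", inv_std(4.0, 1e-5));
}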
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp (view file @ e7be2fe8)

@@ -3,6 +3,7 @@
 #pragma once

+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

 namespace ck {
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp → include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp (view file @ e7be2fe8)

@@ -17,33 +17,24 @@ template <typename GridwiseSparseEmbedding,
           typename BetaDataType,
           typename AccDataType,
           typename OutType,
-          typename OutGridDesc>
+          typename OutGridDesc,
+          typename EmbElementwiseOperation,
+          ck::index_t NumEmbeddings>
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    __global__ void kernel_sparse_embedding3_forward_layernorm(OutType* p_out,
-                                                               const EmbType* p_emb_a,
-                                                               const EmbType* p_emb_b,
-                                                               const EmbType* p_emb_c,
-                                                               const IndexType* p_index_a,
-                                                               const IndexType* p_index_b,
-                                                               const IndexType* p_index_c,
-                                                               const GammaDataType* p_gamma,
-                                                               const BetaDataType* p_beta,
-                                                               const OutGridDesc out_grid_desc,
-                                                               const AccDataType epsilon)
+    __global__ void kernel_sparse_embeddings_forward_layernorm(
+        OutType* p_out,
+        const ck::Array<EmbType*, NumEmbeddings> p_embs,
+        const ck::Array<IndexType*, NumEmbeddings> p_indexes,
+        const GammaDataType* p_gamma,
+        const BetaDataType* p_beta,
+        const OutGridDesc out_grid_desc,
+        const AccDataType epsilon,
+        const EmbElementwiseOperation emb_elementwise_op)
 {
-    GridwiseSparseEmbedding::Run(p_out,
-                                 p_emb_a,
-                                 p_emb_b,
-                                 p_emb_c,
-                                 p_index_a,
-                                 p_index_b,
-                                 p_index_c,
-                                 p_gamma,
-                                 p_beta,
-                                 out_grid_desc,
-                                 epsilon);
+    GridwiseSparseEmbedding::Run(
+        p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op);
 }

 template <typename EmbType,
@@ -53,14 +44,16 @@ template <typename EmbType,
           typename AccDataType,
           typename OutType,
           typename OutGridDesc,
+          typename EmbElementwiseOperation,
           ck::index_t BlockSize,
           ck::index_t DimClusterSize,
           ck::index_t RowClusterSize,
           ck::index_t DimPerBlock,   // Row x Dim, along Dim
           ck::index_t RowPerBlock,   // Row x Dim, along Row
           ck::index_t DimThreadSize, // this is actually not vector, but number of registers
-          ck::index_t RowVectorSize>
-struct GridwiseSparseEmbedding3ForwardLayernorm
+          ck::index_t RowVectorSize,
+          ck::index_t NumEmbeddings>
+struct GridwiseSparseEmbeddingsForwardLayernorm
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -97,23 +90,17 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;

-    __device__ static void Run(OutType* p_out,
-                               const EmbType* p_emb_a,
-                               const EmbType* p_emb_b,
-                               const EmbType* p_emb_c,
-                               const IndexType* p_index_a,
-                               const IndexType* p_index_b,
-                               const IndexType* p_index_c,
-                               const GammaDataType* p_gamma,
-                               const BetaDataType* p_beta,
-                               const OutGridDesc,
-                               const AccDataType epsilon)
+    __device__ static void Run(OutType* p_out,
+                               const ck::Array<EmbType*, NumEmbeddings> p_embs,
+                               const ck::Array<IndexType*, NumEmbeddings> p_indexes,
+                               const GammaDataType* p_gamma,
+                               const BetaDataType* p_beta,
+                               const OutGridDesc,
+                               const AccDataType epsilon,
+                               const EmbElementwiseOperation emb_elementwise_op)
     {
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
-        // const auto index_length = out_grid_desc.GetLength(I0);
-        // const auto emb_dim = out_grid_desc.GetLength(I1);
         constexpr auto thread_cluster_desc =
             make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
@@ -141,13 +128,11 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         constexpr auto gamma_beta_buf_desc =
             make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));

-        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_a;
-        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_b;
-        StaticBuffer<AddressSpaceEnum::Vgpr, EmbType, thread_buf_size, true> in_thread_buf_c;
-        StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_a;
-        StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_b;
-        StaticBuffer<AddressSpaceEnum::Sgpr, IndexType, DimPerBlock, true> index_buf_c;
+        ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true>,
+                  NumEmbeddings>
+            in_thread_bufs;
+        ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true>,
+                  NumEmbeddings>
+            index_bufs;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
@@ -160,42 +145,31 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;

         auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
-            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_a;
-            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_b;
-            vector_type_maker_t<EmbType, RowVectorSize> emb_vector_c;
-            using src_vector_t = typename decltype(emb_vector_a)::type;
+            ck::Array<vector_type_maker_t<EmbType, RowVectorSize>, NumEmbeddings> emb_vectors;
+            auto emb_a         = emb_vectors[0];
+            using src_vector_t = typename decltype(emb_a)::type;
             static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
                 constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
-                IndexType index_a  = index_buf_a[Number<current_dim>{}];
-                IndexType index_b  = index_buf_b[Number<current_dim>{}];
-                IndexType index_c  = index_buf_c[Number<current_dim>{}];
                 auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
                                      sizeof(EmbType) * RowVectorSize;
-                int32x4_t emb_res_a =
-                    make_wave_buffer_resource_with_default_range(p_emb_a + index_a * RowPerBlock);
-                int32x4_t emb_res_b =
-                    make_wave_buffer_resource_with_default_range(p_emb_b + index_b * RowPerBlock);
-                int32x4_t emb_res_c =
-                    make_wave_buffer_resource_with_default_range(p_emb_c + index_c * RowPerBlock);
-                emb_vector_a.template AsType<src_vector_t>()(I0) =
-                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_a, thread_offset, 0);
-                emb_vector_b.template AsType<src_vector_t>()(I0) =
-                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_b, thread_offset, 0);
-                emb_vector_c.template AsType<src_vector_t>()(I0) =
-                    amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res_c, thread_offset, 0);
+                static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
+                    IndexType index   = index_bufs[i_embedding_][Number<current_dim>{}];
+                    int32x4_t emb_res = make_wave_buffer_resource_with_default_range(
+                        p_embs[i_embedding_] + index * RowPerBlock);
+                    emb_vectors(i_embedding_).template AsType<src_vector_t>()(I0) =
+                        amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
+                });
                 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                     constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                         make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
-                    in_thread_buf_a(Number<register_offset>{}) =
-                        emb_vector_a.template AsType<EmbType>()[i_row_vec_];
-                    in_thread_buf_b(Number<register_offset>{}) =
-                        emb_vector_b.template AsType<EmbType>()[i_row_vec_];
-                    in_thread_buf_c(Number<register_offset>{}) =
-                        emb_vector_c.template AsType<EmbType>()[i_row_vec_];
+                    static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
+                        in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
+                            ck::type_convert<AccDataType>(
+                                emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]);
+                    });
                 });
             });
         };
@@ -205,14 +179,15 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
             static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                     make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
-                AccDataType va =
-                    ck::type_convert<AccDataType>(in_thread_buf_a(Number<register_offset>{}));
-                AccDataType vb =
-                    ck::type_convert<AccDataType>(in_thread_buf_b(Number<register_offset>{}));
-                AccDataType vc =
-                    ck::type_convert<AccDataType>(in_thread_buf_c(Number<register_offset>{}));
-                acc_thread_buf(Number<register_offset>{}) += va + vb + vc;
+                auto in_data_refs = generate_tie(
+                    [&](auto i_embedding_) -> const auto& {
+                        return in_thread_bufs(i_embedding_)(Number<register_offset>{});
+                    },
+                    Number<NumEmbeddings>{});
+                auto out_data_refs = generate_tie(
+                    [&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
+                    Number<1>{});
+                unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
             });
         };
@@ -242,7 +217,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
                 constexpr auto mean_var_offset =
                     mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
+                auto divisor =
+                    1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
                 static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
                     constexpr auto register_offset = thread_buf_desc.CalculateOffset(
                         make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
@@ -250,9 +226,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
                         gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
                     auto acc_val = acc_thread_buf[Number<register_offset>{}];
-                    acc_val      = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) /
-                              sqrt(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
+                    acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) * divisor;
                     acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
                               beta_thread_buf[Number<gamma_beta_offset>{}];
                     out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
@@ -273,9 +248,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         // first load index
         ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
             // prefer use s_load
-            index_buf_a(i_idx_) = p_index_a[index_start + i_idx_.value];
-            index_buf_b(i_idx_) = p_index_b[index_start + i_idx_.value];
-            index_buf_c(i_idx_) = p_index_c[index_start + i_idx_.value];
+            ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
+                index_bufs(i_embedding_)(i_idx_) =
+                    p_indexes[i_embedding_][index_start + i_idx_.value];
+            });
         });
         // load gamma/beta
@@ -329,7 +305,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
         static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
             if constexpr(I > 0)
                 block_sync_lds();
-
             BlockwiseWelford::Run(
                 mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
         });
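The rewrite above folds the fixed p_emb_a/p_emb_b/p_emb_c pointers (and the matching index and register buffers) into NumEmbeddings-sized arrays walked by a compile-time loop. Below is a small standalone sketch of that refactor shape, using std::array and an index_sequence in place of ck::Array and static_for; every name here is illustrative rather than CK's.

#include <array>
#include <cstddef>
#include <cstdio>
#include <utility>

// Apply a callable for each compile-time index 0..N-1 (plays the role of static_for).
template <typename F, std::size_t... I>
void for_each_index_impl(F&& f, std::index_sequence<I...>)
{
    (f(std::integral_constant<std::size_t, I>{}), ...);
}

template <std::size_t N, typename F>
void for_each_index(F&& f)
{
    for_each_index_impl(std::forward<F>(f), std::make_index_sequence<N>{});
}

template <std::size_t NumEmbeddings>
float gather_and_sum(const std::array<const float*, NumEmbeddings>& tables,
                     const std::array<const int*, NumEmbeddings>& indices,
                     int row, int col, int row_stride)
{
    float acc = 0.f;
    for_each_index<NumEmbeddings>([&](auto e) {
        // One lookup per embedding table, like the per-i_embedding_ buffer loads above.
        acc += tables[e][indices[e][row] * row_stride + col];
    });
    return acc;
}

int main()
{
    const float t0[] = {1, 2, 3, 4};
    const float t1[] = {10, 20, 30, 40};
    const int   i0[] = {1};
    const int   i1[] = {0};
    std::printf("%f\n", gather_and_sum<2>({t0, t1}, {i0, i1}, 0, 1, 2)); // 4 + 20 = 24
}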
include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp (view file @ e7be2fe8)

@@ -75,4 +75,63 @@ struct ThreadwiseWelford
     int max_count_;
 };

+template <typename T,
+          typename SrcMeanVarCountThreadDesc_M_K,
+          typename DstMeanVarThreadDesc_M,
+          bool GetActualVariance = false>
+struct ThreadwiseWelfordMerge
+{
+    static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
+    static constexpr auto dst_thread_desc_m   = DstMeanVarThreadDesc_M{};
+
+    static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
+    static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
+    static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
+    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
+
+    __device__ static void
+    Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
+    {
+        int count            = count_a + count_b;
+        T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
+        T delta              = mean_b - mean_a;
+        mean_a += delta * count_b_over_count;
+        var_a += var_b + delta * delta * count_a * count_b_over_count;
+        count_a = count;
+    }
+
+    template <typename SrcMeanBufferType,
+              typename SrcVarBufferType,
+              typename SrcCountBufferType,
+              typename DstMeanBufferType,
+              typename DstVarBufferType,
+              typename DstCountBufferType>
+    __device__ static void Run(const SrcMeanBufferType& src_mean_buf,
+                               const SrcVarBufferType& src_var_buf,
+                               const SrcCountBufferType& src_count_buf,
+                               DstMeanBufferType& dst_mean_buf,
+                               DstVarBufferType& dst_var_buf,
+                               DstCountBufferType& dst_count_buf)
+    {
+        static_for<0, src_length_m, 1>{}([&](auto iM) {
+            static_for<0, src_length_k, 1>{}([&](auto iK) {
+                constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                Merge(dst_mean_buf(iM),
+                      dst_var_buf(iM),
+                      dst_count_buf(iM),
+                      src_mean_buf[Number<src_offset>{}],
+                      src_var_buf[Number<src_offset>{}],
+                      src_count_buf[Number<src_offset>{}]);
+            });
+
+            if constexpr(GetActualVariance)
+            {
+                dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM];
+            };
+        });
+    };
+};
+
 } // namespace ck
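ThreadwiseWelfordMerge above combines per-thread (mean, sum-of-squared-deviations, count) partials with the standard parallel Welford update, dividing by the final count only when GetActualVariance is set. Below is a host-side check of that arithmetic against a direct two-pass computation; the merge function mirrors Merge's formula but is a standalone sketch, not CK code.

#include <cstdio>
#include <vector>

struct Partial { double mean = 0, m2 = 0; int count = 0; };

// Same update as ThreadwiseWelfordMerge::Merge: 'm2' holds the sum of squared
// deviations until it is divided by the final count.
void merge(Partial& a, const Partial& b)
{
    int count                 = a.count + b.count;
    double count_b_over_count = count == 0 ? 0.0 : double(b.count) / count;
    double delta              = b.mean - a.mean;
    a.mean += delta * count_b_over_count;
    a.m2 += b.m2 + delta * delta * a.count * count_b_over_count;
    a.count = count;
}

Partial accumulate(const std::vector<double>& x)
{
    Partial p;
    for(double v : x)
    {
        // Scalar Welford update, one element at a time.
        p.count += 1;
        double delta = v - p.mean;
        p.mean += delta / p.count;
        p.m2 += delta * (v - p.mean);
    }
    return p;
}

int main()
{
    std::vector<double> left{1, 2, 3, 4}, right{10, 20, 30};
    Partial a = accumulate(left), b = accumulate(right);
    merge(a, b);

    // Two-pass reference over the concatenated data.
    std::vector<double> all = left;
    all.insert(all.end(), right.begin(), right.end());
    double mean = 0;
    for(double v : all) mean += v;
    mean /= all.size();
    double m2 = 0;
    for(double v : all) m2 += (v - mean) * (v - mean);

    std::printf("merged mean %.6f var %.6f | reference mean %.6f var %.6f\n",
                a.mean, a.m2 / a.count, mean, m2 / all.size());
}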
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp (new file, 0 → 100644, view file @ e7be2fe8)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_wmma.hpp"
namespace ck {

enum struct WmmaInstr
{
    wmma_f32_16x16x16_f16 = 0,
    wmma_f32_16x16x16_bf16,
    wmma_f16_16x16x16_f16,
    wmma_bf16_16x16x16_bf16,
    wmma_i32_16x16x16_iu8,
    wmma_i32_16x16x16_iu4
};
/*
* WMMA Wave Tile Always MxNxK = 16x16x16
* WAVE32
-----------------------------------
|RC0| | | | | | | | | | | | | | | | SubGroup 0
|RC1| | | | | | | | | | | | | | | |
|RC2| | | | | | | | | | | | | | | |
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
|RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC6| | | | | | | | | | | | | | | |
|RC7| | | | | | | | | | | | | | | |
-----------------------------------
| | | | | | | | | | | | | | | | | SubGroup 1
| | | | | | | | | | | | | | | | |
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
| | | | | | | | | | | | | | | | |
-----------------------------------
* WAVE64
-----------------------------------
|RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
|RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
|RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
|RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
| 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
| 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
| 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
| 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
| | | | | | | | | | | | | | | | |
-----------------------------------
| T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
| 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
| 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
| | | | | | | | | | | | | | | | |
-----------------------------------
* RC = Register for storing accumalted result
* T = Thread ID
*/
template <WmmaInstr Instr, index_t WaveSize, typename = void>
struct wmma_type
{
};

// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    // * Data Pixel
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    // * Fixed in Navi3x, Will be wave mode dependent on Navi4x
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    // * num_acc_vgprs_per_wave alone M direction
    // * num_subgroups alone M direction
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_f32_16x16x16_f16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
    }
};

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_f32_16x16x16_bf16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
    }
};

#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 2;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, index_t Opsel, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
    }
};

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 2;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, index_t Opsel, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
        }
    }
};
#endif

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma      = 16;
    static constexpr index_t n_per_wmma      = 16;
    static constexpr index_t k_per_wmma      = 16;
    static constexpr index_t src_a_data_size = 2;
    static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size   = 4;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent propety
    static constexpr index_t wave_size = Number<WaveSize>{};
    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave =
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma,
              index_t NPerWmma,
              bool neg_a,
              bool neg_b,
              bool clamp,
              class FloatA,
              class FloatB,
              class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
        {
            intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(a, b, reg_c);
        }
        else if constexpr(wave_size == 64)
        {
            intrin_wmma_i32_16x16x16_iu8_w64<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(a, b, reg_c);
        }
    }
};

template <typename src_type_a, typename src_type_b, typename dst_type, index_t MPerWmma, index_t NPerWmma>
struct WmmaSelector
{
    template <typename src_type_a_,
              typename src_type_b_,
              typename dst_type_,
              index_t MPerWmma_,
              index_t NPerWmma_>
    static constexpr auto GetWmma();

    template <>
    static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
    {
        return WmmaInstr::wmma_f32_16x16x16_f16;
    }

    template <>
    static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
    {
        return WmmaInstr::wmma_f32_16x16x16_bf16;
    }

    template <>
    static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
    {
        return WmmaInstr::wmma_f16_16x16x16_f16;
    }

    template <>
    static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
    {
        return WmmaInstr::wmma_bf16_16x16x16_bf16;
    }

    template <>
    static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
    {
        return WmmaInstr::wmma_i32_16x16x16_iu8;
    }

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <>
    static constexpr auto GetWmma<int4_t, int, 16, 16>()
    {
        return WmmaInstr::wmma_i32_16x16x16_iu4;
    }
#endif

    // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
    static constexpr auto selected_wmma =
        wmma_type<GetWmma<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};

    __host__ __device__ constexpr WmmaSelector()
    {
        static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

        static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

        static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

        static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
                              selected_wmma.acc_data_size ==
                          selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                      "WRONG! Invalid Number of Accumulator Register");
    }
};

template <typename src_type_a,
          typename src_type_b,
          typename dst_type,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t KPack,
          bool TransposeC = false>
struct WmmaGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex4D = MultiIndex<4>;

    __host__ __device__ constexpr WmmaGemm()
    {
        static_assert(NPerWmma == 16 && MPerWmma == 16,
                      "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");

        static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
    }

    // WMMA output supporting C = A * B
    // Vector Write
    // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
    __host__ __device__ static constexpr auto
    MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
    {
        const auto MBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
        const auto NBlockxRepeat =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
        const auto MWave =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
        const auto NWave =
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);

        return transform_tensor_descriptor(
            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
            make_tuple(
                make_pass_through_transform(MBlockxRepeat),
                make_pass_through_transform(MWave),
                make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
                                                  Number<wmma_instr.num_acc_vgprs_per_wave>{})),
                make_pass_through_transform(NBlockxRepeat),
                make_pass_through_transform(NWave),
                make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 6>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
    }

    __device__ static constexpr index_t GetRegSizePerWmma()
    {
        return wmma_instr.num_acc_vgprs_per_wave;
    }

    __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
    {
        static_assert(
            (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
             is_same<dst_type, float>::value) ||
                (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
                 is_same<dst_type, float>::value) ||
                (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
                 is_same<dst_type, half_t>::value) ||
                (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
                 is_same<dst_type, bhalf_t>::value) ||
                (is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
                 is_same<dst_type, int32_t>::value)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                || (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
                    is_same<dst_type, int32_t>::value)
#endif
                ,
            "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
            "(int8, int32) or (int4, int32)!");

        if constexpr(!TransposeC)
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
        }
        else
        {
            wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
        }
    }

    __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }

    __device__ static auto GetSubGroupId()
    {
        return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
    }

    __device__ static auto GetLaneIdUnderSubGroup()
    {
        return GetLaneId() % wmma_instr.num_thread_per_subgroups;
    }

    __device__ static auto GetSwizzledLaneIdLow()
    {
        return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
    }

    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
    {
        return GetSwizzledLaneIdLow();
    }

    __host__ __device__ static auto CalculateBThreadOriginDataIndex()
    {
        return GetLaneIdUnderSubGroup();
    }

    __device__ static CIndex GetBeginOfThreadBlk()
    {
        index_t n_offset = GetLaneIdUnderSubGroup();
        index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;

        return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
    }

    static constexpr auto wmma =
        WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};

    static constexpr auto wmma_instr = wmma.selected_wmma;

    __host__ __device__ static constexpr auto
    GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
    {
        return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
    }
};

} // namespace ck
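Every wmma_type specialization above sizes its accumulator as num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4, and for the 4-byte accumulator case the WmmaSelector constructor asserts the same identity rearranged. A tiny constexpr sketch plugging in the 16x16 f32 numbers:

#include <cstdio>

constexpr int acc_vgprs_per_wave(int m, int n, int acc_data_size, int wave_size)
{
    // Accumulator bytes per lane, divided by 4 bytes per 32-bit VGPR.
    return m * n * acc_data_size / wave_size / 4;
}

static_assert(acc_vgprs_per_wave(16, 16, /*acc_data_size=*/4, /*wave_size=*/32) == 8);
static_assert(acc_vgprs_per_wave(16, 16, /*acc_data_size=*/4, /*wave_size=*/64) == 4);
// For acc_data_size == 4 this matches the WmmaSelector static_assert:
// wave_size * num_acc_vgprs_per_wave * acc_data_size == m * n * 4.
static_assert(32 * 8 * 4 == 16 * 16 * 4);

int main()
{
    std::printf("wave32: %d VGPRs, wave64: %d VGPRs\n",
                acc_vgprs_per_wave(16, 16, 4, 32), acc_vgprs_per_wave(16, 16, 4, 64));
}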
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp (view file @ e7be2fe8)

@@ -593,7 +593,8 @@ struct XdlopsGemm
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};

     using CIndex = MultiIndex<2>;
+    using CIndex4D = MultiIndex<4>;

     __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
@@ -822,6 +823,16 @@ struct XdlopsGemm
         return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
     }

+    __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
+    {
+        const auto blk_idx = GetBlkIdx();
+
+        const auto blk_id = blk_idx[I0];
+        const auto blk_td = blk_idx[I1];
+
+        return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
+    }
+
     static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};

     static constexpr auto mfma_instr = mfma.selected_mfma;
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
0 → 100644
View file @
e7be2fe8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
namespace
ck
{
namespace
tensor_operation {

// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimG, index_t NumDimM, index_t NumDimN, device::TensorSpecialization TensorSpec>
static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec,
                                   const std::vector<index_t>& gs_ms_ns_strides_vec)
{
    if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
         gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
    {
        throw std::runtime_error("wrong! dimension must match input lengths");
    }

    const auto to_tuple = [&](auto& vec, auto start, auto end) {
        return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
    };

    const auto gs_ms_ns_lengths =
        to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
    const auto gs_ms_ns_strides =
        to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

    // dimension Ids for G0, G1, ...
    constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

    // dimension Ids for M0, M1, ...
    constexpr auto mDimIds =
        typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

    // dimension Ids for N0, N1, ...
    constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
                                                              NumDimG + NumDimM + NumDimN,
                                                              1>::type{};

    // lengths for G0, G1, ...
    const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);

    // lengths for M0, M1, ...
    const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);

    // lengths for N0, N1, ...
    const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);

    if constexpr(TensorSpec == device::TensorSpecialization::Packed)
    {
        auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
        auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
        auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});

        const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(G, M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
    else
    {
        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
        const auto grid_desc_gs_ms_ns =
            make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);

        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 * N2 * ...]
        // Note: This does not require padding as it only provides G offset calculation. Technically
        // descriptor for only G is needed. Here we opt for backward compatibility purpose to return
        // G_M_N
        const auto grid_desc_g_mraw_nraw = transform_tensor_descriptor(
            grid_desc_gs_ms_ns,
            make_tuple(make_merge_transform(gLengths),
                       make_merge_transform(mLengths),
                       make_merge_transform(nLengths)),
            make_tuple(gDimIds, mDimIds, nDimIds),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        const auto c_ms_ns_lengths = to_tuple(
            gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_ms_ns_strides = to_tuple(
            gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // transformed tensor C[MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 * N2 * ...]
        const auto grid_desc_ms_ns =
            make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);

        const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
            grid_desc_ms_ns,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
            make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
}
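For the Packed specialization the helper simply collapses each group of dimensions into a single length (G = G0 * G1 * ..., and likewise for M and N) and keeps only the stride of the innermost dimension of each group. A minimal host-side sketch of that flattening with plain integers (illustrative only; the real code uses the CK descriptor types above, and the sizes here are made up):

#include <array>
#include <cstdio>

int main()
{
    // hypothetical packed contraction C[G0, G1, M0, M1, N0, N1]
    std::array<int, 6> lengths = {2, 3, 4, 8, 5, 16};
    std::array<int, 6> strides = {7680, 2560, 640, 80, 16, 1}; // fully packed, last dim fastest

    // NumDimG = NumDimM = NumDimN = 2
    const int G = lengths[0] * lengths[1]; // 6
    const int M = lengths[2] * lengths[3]; // 32
    const int N = lengths[4] * lengths[5]; // 80

    // Packed case: only the innermost stride of each group is kept
    const int g_stride = strides[1]; // stride of G1 == M * N
    const int m_stride = strides[3]; // stride of M1 == N
    const int n_stride = strides[5]; // stride of N1 == 1

    std::printf("G=%d M=%d N=%d, strides=(%d, %d, %d)\n", G, M, N, g_stride, m_stride, n_stride);
    return 0;
}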
template <typename NumDims_G_M_N_K_O, // Sequence<>
          typename PerBlock_M_N_K_O,  // Sequence<>
          device::GemmSpecialization GemmSpec,
          device::TensorSpecialization ASpec,
          device::TensorSpecialization B0Spec,
          device::TensorSpecialization B1Spec,
          device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
    static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
    static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
    static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
    static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);

    static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
    static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
    static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
    static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);

    static constexpr auto matrix_padder =
        device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, OPerBlock};

    //
    // A
    //
    static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
                                                                        a_gs_ms_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_KRaw
    static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                          const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
    }

    static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return matrix_padder.PadADescriptor_M_K(
            MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
    }

    template <typename AGridDesc_M_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    //
    // B (alias of B0)
    //
    static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
                                         const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
                                                                         b0_gs_ns_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
                                           const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
    {
        return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
    }

    static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
                                         const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
    {
        // alias of matrix_padder.PadB0Descriptor_N_K
        return matrix_padder.PadBDescriptor_N_K(
            MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
    }

    template <typename BGridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    //
    // B1
    //
    static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
                                         const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
                                                                         b1_gs_os_ns_strides_vec);
    }

    // TODO: rename to G_NRaw_KRaw
    static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
                                           const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
        return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
    }

    static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
                                         const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
        // alias of matrix_padder.PadB1Descriptor_O_N
        return matrix_padder.PadB1Descriptor_N_K(
            MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
    }

    template <typename B1GridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
    {
        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);

        const auto B1K0 = K / B1K1;

        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    //
    // C
    //
    static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
                                        const std::vector<index_t>& c_gs_ms_os_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
                                                                        c_gs_ms_os_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
                                          const std::vector<index_t>& c_gs_ms_os_strides_vec)
    {
        return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
    }

    static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
                                        const std::vector<index_t>& c_gs_ms_os_strides_vec)
    {
        return matrix_padder.PadCDescriptor_M_N(
            MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
    }
};

} // namespace tensor_operation
} // namespace ck
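MakeAGridDescriptor_AK0_M_AK1 (and the B0/B1 counterparts) split the K axis of the padded [M, K] descriptor into (AK0, AK1) with AK0 = K / AK1, where AK1 is typically the vector length read along K. A small arithmetic sketch of the index mapping performed by the unmerge transform, with made-up sizes:

#include <cassert>

int main()
{
    const int K = 64, AK1 = 8; // hypothetical sizes; K must be divisible by AK1
    const int AK0 = K / AK1;   // 8

    // a flat k index decomposes as k = k0 * AK1 + k1, matching the (AK0, AK1) unmerge
    for(int k = 0; k < K; ++k)
    {
        const int k0 = k / AK1;
        const int k1 = k % AK1;
        assert(k0 < AK0 && k0 * AK1 + k1 == k);
    }
    return 0;
}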
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
View file @ e7be2fe8

@@ -4,6 +4,7 @@
 #pragma once
 
+#include "ck/library/utility/numeric.hpp"
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -47,10 +48,9 @@ struct TransformConvFwdToGemm
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                    c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                    index_t{1},
-                                                    std::multiplies<index_t>());
+            const index_t NWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
             const auto in_gemmm_gemmk_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(NWo, C));
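This commit replaces the iterator-pair std::accumulate calls with ck::accumulate_n, which takes a starting iterator and an element count instead of an end iterator. Its definition lives in the newly included ck/library/utility/numeric.hpp; judging only from the call sites it is presumably a thin wrapper along these lines (a sketch under that assumption, not the library's actual code):

#include <functional>
#include <iterator>
#include <numeric>
#include <vector>

namespace ck_sketch {

// hypothetical equivalent of ck::accumulate_n<T>(first, n, init, op)
template <typename T, typename ForwardIterator, typename Size, typename BinaryOperation>
T accumulate_n(ForwardIterator first, Size n, T init, BinaryOperation op)
{
    return std::accumulate(first, std::next(first, n), init, op);
}

} // namespace ck_sketch

int main()
{
    std::vector<int> wos_lengths{1, 4, 8, 28, 28}; // e.g. [G, N, K, Ho, Wo] with NDimSpatial = 2
    const int HoWo = ck_sketch::accumulate_n<int>(
        wos_lengths.begin() + 3, 2, 1, std::multiplies<>()); // 28 * 28 = 784
    return HoWo == 784 ? 0 : 1;
}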
@@ -146,10 +146,9 @@ struct TransformConvFwdToGemm
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                      c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                      index_t{1},
-                                                      std::multiplies<index_t>());
+            const index_t NHoWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
             const auto in_gemmm_gemmk_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));
@@ -262,10 +261,8 @@ struct TransformConvFwdToGemm
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
             const index_t NDoHoWo =
-                N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                    c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                    index_t{1},
-                                    std::multiplies<index_t>());
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
             const auto in_gemmm_gemmk_desc =
                 make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));
@@ -390,10 +387,9 @@ struct TransformConvFwdToGemm
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                      c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                      index_t{1},
-                                                      std::multiplies<index_t>());
+            const index_t NHoWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
             // This is different
             const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
@@ -506,10 +502,9 @@ struct TransformConvFwdToGemm
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                      c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                      index_t{1},
-                                                      std::multiplies<index_t>());
+            const index_t NHoWo =
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
             // This is different
             const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
@@ -639,10 +634,8 @@ struct TransformConvFwdToGemm
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
             const index_t NDoHoWo =
-                N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                    c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                    index_t{1},
-                                    std::multiplies<index_t>());
+                N * ck::accumulate_n<index_t>(
+                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
             // This is different
             const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
@@ -768,10 +761,8 @@ struct TransformConvFwdToGemm
         const index_t K = b_g_k_c_xs_lengths[1];
         const index_t C = b_g_k_c_xs_lengths[2];
 
-        const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
-                                           b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
-                                           index_t{1},
-                                           std::multiplies<index_t>());
+        const index_t YX = ck::accumulate_n<index_t>(
+            b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
         const auto wei_gemmn_gemmk_desc =
             make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
@@ -794,10 +785,8 @@ struct TransformConvFwdToGemm
         const index_t K = b_g_k_c_xs_lengths[1];
         const index_t C = b_g_k_c_xs_lengths[2];
 
-        const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
-                                           b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
-                                           index_t{1},
-                                           std::multiplies<index_t>());
+        const index_t YX = ck::accumulate_n<index_t>(
+            b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
         const index_t KStride = b_g_k_c_xs_strides[1];
         const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
@@ -827,10 +816,9 @@ struct TransformConvFwdToGemm
         const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];
 
-        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
         const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
@@ -855,10 +843,9 @@ struct TransformConvFwdToGemm
         const auto KStride     = I1;
         const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
 
-        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
         const auto out_gemmm_gemmn_desc =
             make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
@@ -878,10 +865,9 @@ struct TransformConvFwdToGemm
         const index_t N = c_g_n_k_wos_lengths[1];
         const index_t K = c_g_n_k_wos_lengths[2];
 
-        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
-                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
-                                                  index_t{1},
-                                                  std::multiplies<index_t>());
+        const index_t NHoWo =
+            N * ck::accumulate_n<index_t>(
+                    c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
 
         const auto out_gemmm_gemmn_desc =
             make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1));
include/ck/utility/amd_inline_asm.hpp
View file @ e7be2fe8

@@ -355,5 +355,11 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                     c3);
 }
 
+// Ranged input operand
+__device__ void
+amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
+{
+    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
+}
+
 } // namespace ck
 #endif
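The wrapper feeds the accumulator back through the tied constraint "0"(c), so the same register block is read as the C input and overwritten with the D result. A commented restatement of the same statement, purely for illustration (the semantics follow the standard GCC/Clang inline-asm constraint rules):

// Same instruction as above, with the operand roles spelled out:
//   "=v"(c) -> D operand: output, placed in VGPRs
//   "v"(a)  -> A operand: 16 packed fp16 values
//   "v"(b)  -> B operand: 16 packed fp16 values
//   "0"(c)  -> C operand: tied to operand 0, so the accumulator is read and written in place
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));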
include/ck/utility/amd_wmma.hpp
0 → 100644
View file @ e7be2fe8

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP

#include "ck/utility/amd_inline_asm.hpp"
#include "data_type.hpp"

// TODO: Add arch limitation
namespace ck {

/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;

template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // * Inline assembly need to elimate the duplicated data load, compiler won't help you
        // delete them.
        amd_assembly_wmma_f32_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
        // reg_c.template AsType<float8_t>()(Number<0>{}) =
        // __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template
        // AsType<float8_t>()[Number<0>{}]);
    }
};

// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
    }
};

// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32;

template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
    }
};

// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32;

template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
    }
};

// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
            neg_a,
            bit_cast<int32x4_t>(reg_a),
            neg_b,
            bit_cast<int32x4_t>(reg_b),
            reg_c.template AsType<int32x8_t>()[Number<0>{}],
            clamp);
    }
};

/********************************WAVE64 MODE***********************************************/
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;

template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
    }
};

// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w64;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
    }
};

// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64;

template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
    }
};

// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64;

template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
                reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
    }
};

// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
            neg_a,
            bit_cast<int32x4_t>(reg_a),
            neg_b,
            bit_cast<int32x4_t>(reg_b),
            reg_c.template AsType<int32x4_t>()[Number<0>{}],
            clamp);
    }
};

} // namespace ck
#endif
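These wrappers are intended to be called from a warp-level GEMM with the accumulator held in a CK vector container so that AsType<float8_t> can view it as the packed WMMA destination. A minimal device-side usage sketch under that assumption (the surrounding kernel, data staging and register layout are omitted, and the helper name is hypothetical):

// Hypothetical device-side use of the wave32 fp16 WMMA wrapper; assumes half16_t,
// float8_t and vector_type<> come from ck/utility/data_type.hpp as in the header above.
__device__ void wmma_tile_accumulate(const ck::half16_t& a_frag,
                                     const ck::half16_t& b_frag,
                                     ck::vector_type<float, 8>& c_frag)
{
    // c_frag accumulates one 16x16 output fragment across repeated calls (C += A * B)
    ck::intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(a_frag, b_frag, c_frag);
}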
include/ck/utility/amd_xdlops.hpp
View file @ e7be2fe8

@@ -254,7 +254,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
     template <class FloatC>
     __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
     {
-        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
             reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
     }
 };
include/ck/utility/math_v2.hpp
View file @ e7be2fe8

@@ -3,7 +3,9 @@
 #pragma once
 
+#ifndef __HIP_DEVICE_COMPILE__
 #include <cmath>
+#endif
 
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/type.hpp"
@@ -114,7 +116,16 @@ static inline __device__ int4_t abs(int4_t x)
 };
 #endif
 
-static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
+static inline __device__ half_t abs(half_t x)
+{
+    uint16_t xx = ck::bit_cast<uint16_t>(x);
+
+    uint16_t abs_xx = xx & 0x7fff;
+
+    half_t abs_x = ck::bit_cast<half_t>(abs_xx);
+
+    return abs_x;
+};
 
 static inline __device__ bool isnan(float x) { return ::isnan(x); };
@@ -140,11 +151,16 @@ static inline __device__ bool isnan(int4_t x)
 };
 #endif
 
-static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
+static inline __device__ bool isnan(half_t x)
+{
+    uint16_t xx = ck::bit_cast<uint16_t>(x);
+
+    return (xx & 0x7FFF) > 0x7C00;
+};
 
-static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
+static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
 
-static inline __device__ double sqrt(double x) { return ::sqrt(x); };
+static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
 
 } // namespace math
 } // namespace ck
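The new half_t overloads avoid the device intrinsics by working on the IEEE fp16 bit pattern directly: abs clears the sign bit (bit 15), and a value is NaN exactly when its magnitude bits exceed the infinity pattern, i.e. (bits & 0x7FFF) > 0x7C00. A small host-side check of the same bit tricks on uint16_t (illustrative only; it does not use ck::half_t or ck::bit_cast):

#include <cassert>
#include <cstdint>

int main()
{
    const std::uint16_t pos_one   = 0x3C00; // +1.0 in IEEE fp16
    const std::uint16_t neg_one   = 0xBC00; // -1.0
    const std::uint16_t pos_inf   = 0x7C00; // +inf: exponent all ones, mantissa zero
    const std::uint16_t quiet_nan = 0x7E00; // a NaN: exponent all ones, mantissa non-zero

    // abs: clear the sign bit
    assert((neg_one & 0x7FFF) == pos_one);

    // isnan: magnitude strictly above the infinity pattern
    assert(((quiet_nan & 0x7FFF) > 0x7C00));
    assert(!((pos_inf & 0x7FFF) > 0x7C00));
    assert(!((pos_one & 0x7FFF) > 0x7C00));
    return 0;
}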
include/ck/utility/reduction_operator.hpp
View file @ e7be2fe8

@@ -251,27 +251,27 @@ constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum o
 };
 
 template <InMemoryDataOperationEnum Operation, typename DataType>
-struct InMemoryDataOperatonSupportedOnDataType
+struct InMemoryDataOperationSupportedOnDataType
 {
     static constexpr bool value = false;
 };
 
 template <typename DataType>
-struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
+struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
 {
     static constexpr bool value =
         is_same<DataType, float>::value || is_same<DataType, double>::value;
 };
 
 template <typename DataType>
-struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
+struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
 {
     static constexpr bool value =
         is_same<DataType, float>::value || is_same<DataType, double>::value;
 };
 
 template <typename DataType>
-struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
+struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
 {
     static constexpr bool value =
         is_same<DataType, float>::value || is_same<DataType, double>::value ||
@@ -280,7 +280,7 @@ struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, D
 };
 
 template <typename DataType>
-struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
+struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
 {
     static constexpr bool value =
         is_same<DataType, float>::value || is_same<DataType, double>::value ||
include/ck/utility/synchronization.hpp
View file @ e7be2fe8

@@ -18,6 +18,7 @@ __device__ void block_sync_lds()
     __syncthreads();
 #endif
 }
 
 __device__ void s_nop()
 {
 #if 1
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_backward.hpp
0 → 100644
View file @ e7be2fe8

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <array>
#include <algorithm>
#include <thread>

#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <typename XDataType,
          typename DxDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          index_t Rank,
          index_t NumBatchNormReduceDim>
struct ReferenceBatchNormBwd : public device::DeviceBatchNormBwd<XDataType,
                                                                 DxDataType,
                                                                 DyDataType,
                                                                 AccDataType,
                                                                 ScaleDataType,
                                                                 DscaleDbiasDataType,
                                                                 MeanVarDataType,
                                                                 DyElementwiseOp,
                                                                 Rank,
                                                                 NumBatchNormReduceDim>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

    struct Argument : public device::BaseArgument
    {
        Argument(const std::array<index_t, Rank> xyLengths,
                 const std::array<index_t, Rank> xStrides,
                 const std::array<index_t, Rank> dxStrides,
                 const std::array<index_t, Rank> dyStrides,
                 const std::array<int, NumBatchNormReduceDim> reduceDims,
                 const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                 const std::array<index_t, NumInvariantDim> bnScaleStrides,
                 const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
                 const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const DyDataType* p_dy,
                 const ScaleDataType* p_scale,
                 const MeanVarDataType* p_savedMean,
                 const MeanVarDataType* p_savedInvVar,
                 double epsilon,
                 const DyElementwiseOp dy_elementwise_op,
                 DxDataType* p_dx,
                 DscaleDbiasDataType* p_dscale,
                 DscaleDbiasDataType* p_dbias)
            : reduceDims_(reduceDims),
              bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
              bnScaleStrides_(bnScaleStrides),
              bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
              bnMeanVarStrides_(bnMeanVarStrides),
              p_x_(p_x),
              p_dy_(p_dy),
              p_scale_(p_scale),
              p_savedMean_(p_savedMean),
              p_savedInvVar_(p_savedInvVar),
              dy_elementwise_op_(dy_elementwise_op),
              p_dx_(p_dx),
              p_dscale_(p_dscale),
              p_dbias_(p_dbias)
        {
            using ck::host_common::get_index_set;

            if(std::any_of(reduceDims.begin(), reduceDims.end(), [](int d) {
                   return d < 0 || d >= Rank;
               }))
                throw std::runtime_error("Invalid reduce dimensions!");

            // get invariant_dims[] and invariant_lengths[]
            for(int dim = 0, i = 0; dim < Rank; dim++)
                if(std::none_of(
                       reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
                {
                    invariantDims_[i]     = dim;
                    invariant_lengths_[i] = xyLengths[dim];
                    i++;
                };

            // get reduce_lengths_[]
            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
            {
                int dim              = reduceDims[j];
                reduce_lengths_[i++] = xyLengths[dim];
            };

            for(int i = 0; i < NumInvariantDim; i++)
                if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
                    throw std::runtime_error("Invalid lengths parameters!");

            for(int j = 0, i = 0; j < NumInvariantDim; j++)
            {
                int dim                  = invariantDims_[j];
                x_invariant_strides_[i]  = xStrides[dim];
                dy_invariant_strides_[i] = dyStrides[dim];
                dx_invariant_strides_[i] = dxStrides[dim];
                i++;
            };

            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
            {
                int dim               = reduceDims_[j];
                x_reduce_strides_[i]  = xStrides[dim];
                dy_reduce_strides_[i] = dyStrides[dim];
                dx_reduce_strides_[i] = dxStrides[dim];
                i++;
            };

            reduceSize_ = std::accumulate(
                reduce_lengths_.begin(), reduce_lengths_.end(), 1, std::multiplies<size_t>{});

            invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
            reduce_index_set_    = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);

            epsilon_ = type_convert<AccDataType>(epsilon);

            haveSavedMeanInvVar_ = (p_savedMean != nullptr && p_savedInvVar != nullptr);
        }

        std::array<int, NumBatchNormReduceDim> reduceDims_;
        std::array<int, NumInvariantDim> invariantDims_;

        std::array<index_t, NumInvariantDim> invariant_lengths_;
        std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;

        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
        const std::array<index_t, NumInvariantDim> bnScaleStrides_;
        const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides_;
        const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;

        std::array<index_t, NumInvariantDim> x_invariant_strides_;
        std::array<index_t, NumInvariantDim> dy_invariant_strides_;
        std::array<index_t, NumInvariantDim> dx_invariant_strides_;
        std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
        std::array<index_t, NumBatchNormReduceDim> dy_reduce_strides_;
        std::array<index_t, NumBatchNormReduceDim> dx_reduce_strides_;

        const XDataType* p_x_;
        const DyDataType* p_dy_;
        const ScaleDataType* p_scale_;
        const MeanVarDataType* p_savedMean_;
        const MeanVarDataType* p_savedInvVar_;
        const DyElementwiseOp dy_elementwise_op_;
        DxDataType* p_dx_;
        DscaleDbiasDataType* p_dscale_;
        DscaleDbiasDataType* p_dbias_;

        bool haveSavedMeanInvVar_;

        std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
        std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;

        AccDataType epsilon_;

        size_t reduceSize_;
    };

    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
            using ck::host_common::get_offset_from_index;

            auto thread_reduce_func = [&](auto invariant_index) {
                size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.x_invariant_strides_, invariant_index);
                size_t dy_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.dy_invariant_strides_, invariant_index);
                size_t dx_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.dx_invariant_strides_, invariant_index);

                AccDataType mean     = type_convert<AccDataType>(0.0f);
                AccDataType variance = type_convert<AccDataType>(0.0f);
                AccDataType invVar;

                int32_t curr_count = 0;

                if(arg.haveSavedMeanInvVar_)
                {
                    size_t mean_invVar_invariant_offset = get_offset_from_index<NumInvariantDim>(
                        arg.bnMeanVarStrides_, invariant_index);

                    mean   = type_convert<AccDataType>(arg.p_savedMean_[mean_invVar_invariant_offset]);
                    invVar = type_convert<AccDataType>(arg.p_savedInvVar_[mean_invVar_invariant_offset]);
                }
                else
                {
                    // compute mean, variance using welford method
                    for(const auto& reduce_index : arg.reduce_index_set_)
                    {
                        size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                            arg.x_reduce_strides_, reduce_index);

                        auto x_offset = x_invariant_offset + x_reduce_offset;

                        curr_count++;

                        AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                        AccDataType delta = x - mean;

                        mean += delta / curr_count;

                        AccDataType delta2 = x - mean;

                        variance += delta * delta2;
                    };

                    // actual variance
                    variance = variance / curr_count;

                    // inv-variance defined as 1/sqrt(epsilon+variance)
                    invVar =
                        type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
                };

                AccDataType dbias  = type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy
                AccDataType dscale = type_convert<AccDataType>(0.0f); // Sum on reduced dimensions of dy * norm_x

                // 1) calculate dy * (x - mean) * inv-variance
                // 2) calculate sum(dy) on reduced dimensions
                // 3) calculate sum(dy * norm_x) on reduced dimensions
                for(const auto& reduce_index : arg.reduce_index_set_)
                {
                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.x_reduce_strides_, reduce_index);
                    size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.dy_reduce_strides_, reduce_index);

                    auto x_offset  = x_invariant_offset + x_reduce_offset;
                    auto dy_offset = dy_invariant_offset + dy_reduce_offset;

                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                    AccDataType norm_x = (x - mean) * invVar;

                    AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);

                    arg.dy_elementwise_op_(dy, dy);

                    dbias += dy;
                    dscale += norm_x * dy;
                };

                size_t dscale_offset = get_offset_from_index<NumInvariantDim>(
                    arg.bnDscaleDbiasStrides_, invariant_index);
                size_t dbias_offset = get_offset_from_index<NumInvariantDim>(
                    arg.bnDscaleDbiasStrides_, invariant_index);

                arg.p_dscale_[dscale_offset] = type_convert<DscaleDbiasDataType>(dscale);
                arg.p_dbias_[dbias_offset]   = type_convert<DscaleDbiasDataType>(dbias);

                size_t scale_offset =
                    get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);

                AccDataType scale = type_convert<AccDataType>(arg.p_scale_[scale_offset]);

                AccDataType multiplier = type_convert<AccDataType>(1.0f) /
                                         type_convert<AccDataType>(arg.reduceSize_) * invVar * scale;

                // 1) calculate tmp = dscale * (x - mean) * inv-variance
                // 2) calculate dx = 1/reduceSize * inv-variance * scale * (reduceSize * dy - dbias
                // - tmp)
                for(const auto& reduce_index : arg.reduce_index_set_)
                {
                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.x_reduce_strides_, reduce_index);
                    size_t dy_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.dy_reduce_strides_, reduce_index);
                    size_t dx_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.dx_reduce_strides_, reduce_index);

                    auto x_offset  = x_invariant_offset + x_reduce_offset;
                    auto dy_offset = dy_invariant_offset + dy_reduce_offset;
                    auto dx_offset = dx_invariant_offset + dx_reduce_offset;

                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                    AccDataType norm_x = (x - mean) * invVar;

                    AccDataType dy = type_convert<AccDataType>(arg.p_dy_[dy_offset]);

                    arg.dy_elementwise_op_(dy, dy);

                    AccDataType tmpVal = norm_x * dscale;

                    AccDataType dx =
                        multiplier *
                        (type_convert<AccDataType>(arg.reduceSize_) * dy - dbias - tmpVal);

                    arg.p_dx_[dx_offset] = type_convert<DxDataType>(dx);
                };
            };

            std::size_t num_thread = std::thread::hardware_concurrency();
            std::size_t work_per_thread =
                (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;

            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t i_begin = it * work_per_thread;
                std::size_t i_end   = std::min(static_cast<size_t>((it + 1) * work_per_thread),
                                             arg.invariant_index_set_.size());

                auto f = [=] {
                    for(std::size_t i = i_begin; i < i_end; ++i)
                    {
                        thread_reduce_func(arg.invariant_index_set_[i]);
                    }
                };

                threads[it] = joinable_thread(f);
            }

            return (0.0f);
        };

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /*stream_config*/ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        };
    };

    bool IsSupportedArgument(const device::BaseArgument* p_arg) override
    {
        (void)p_arg;

        return (true);
    };

    std::unique_ptr<device::BaseArgument>
    MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
                        const std::array<index_t, Rank> xStrides,
                        const std::array<index_t, Rank> dxStrides,
                        const std::array<index_t, Rank> dyStrides,
                        const std::array<int, NumBatchNormReduceDim> reduceDims,
                        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                        const std::array<index_t, NumInvariantDim> bnScaleStrides,
                        const std::array<index_t, NumInvariantDim> bnDscaleDbiasStrides,
                        const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                        const void* p_x,
                        const void* p_dy,
                        const void* p_scale,
                        const void* p_savedMean,
                        const void* p_savedInvVar,
                        double epsilon,
                        const DyElementwiseOp dy_elementwise_op,
                        void* p_dx,
                        void* p_dscale,
                        void* p_dbias) override
    {
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          dxStrides,
                                          dyStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnDscaleDbiasStrides,
                                          bnMeanVarStrides,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const DyDataType*>(p_dy),
                                          static_cast<const ScaleDataType*>(p_scale),
                                          static_cast<const MeanVarDataType*>(p_savedMean),
                                          static_cast<const MeanVarDataType*>(p_savedInvVar),
                                          epsilon,
                                          dy_elementwise_op,
                                          static_cast<DxDataType*>(p_dx),
                                          static_cast<DscaleDbiasDataType*>(p_dscale),
                                          static_cast<DscaleDbiasDataType*>(p_dbias));
    };

    std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "Reference_BatchNorm_Backward" << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
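For each invariant (channel) position the reference computes dbias = sum(dy), dscale = sum(dy * norm_x) with norm_x = (x - mean) * invVar, and then dx = (scale * invVar / reduceSize) * (reduceSize * dy - dbias - norm_x * dscale). A minimal single-channel sketch of the same arithmetic in plain C++ (a hypothetical standalone helper, not part of the library; the library version works on strided N-d tensors and runs one channel per host thread):

#include <cmath>
#include <cstddef>
#include <vector>

// Single-channel batchnorm backward following the formulas used by the reference above.
void batchnorm_bwd_1ch(const std::vector<float>& x,
                       const std::vector<float>& dy,
                       float scale,
                       double epsilon,
                       std::vector<float>& dx,
                       float& dscale,
                       float& dbias)
{
    const std::size_t n = x.size();

    // Welford mean/variance, as in the reference when no saved mean/inv-variance is provided
    float mean = 0.f, m2 = 0.f;
    for(std::size_t i = 0; i < n; ++i)
    {
        const float delta = x[i] - mean;
        mean += delta / static_cast<float>(i + 1);
        m2 += delta * (x[i] - mean);
    }
    const float variance = m2 / static_cast<float>(n);
    const float invVar   = 1.f / std::sqrt(static_cast<float>(epsilon) + variance);

    // dbias = sum(dy), dscale = sum(dy * norm_x)
    dbias = dscale = 0.f;
    for(std::size_t i = 0; i < n; ++i)
    {
        const float norm_x = (x[i] - mean) * invVar;
        dbias += dy[i];
        dscale += norm_x * dy[i];
    }

    // dx = 1/n * invVar * scale * (n * dy - dbias - norm_x * dscale)
    const float multiplier = 1.f / static_cast<float>(n) * invVar * scale;
    dx.resize(n);
    for(std::size_t i = 0; i < n; ++i)
    {
        const float norm_x = (x[i] - mean) * invVar;
        dx[i] = multiplier * (static_cast<float>(n) * dy[i] - dbias - norm_x * dscale);
    }
}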