Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
925a8d78
Commit
925a8d78
authored
Dec 19, 2021
by
Chao Liu
Browse files
refactor
parent
681ede91
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
303 additions
and
415 deletions
+303
-415
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp
...tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp
+5
-5
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r1.hpp
...tensor_operation/blockwise_tensor_slice_transfer_v6r1.hpp
+5
-5
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r3.hpp
...tensor_operation/blockwise_tensor_slice_transfer_v6r3.hpp
+5
-5
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
+98
-145
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp
+102
-149
device_operation/include/device_conv2d_fwd_xdl_output_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
...xdl_output_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+39
-48
device_operation/include/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp
...e/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp
+39
-48
example/4_conv2d_fwd_xdl_output_shuffle/conv2d_fwd_xdl_output_shuffle.cpp
..._fwd_xdl_output_shuffle/conv2d_fwd_xdl_output_shuffle.cpp
+5
-5
example/6_conv2d_fwd_xdl_output_shuffle_bias_relu_add/conv2d_fwd_xdl_output_shuffle_bias_relu_add.cpp
..._relu_add/conv2d_fwd_xdl_output_shuffle_bias_relu_add.cpp
+5
-5
No files found.
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v4r1.hpp
View file @
925a8d78
...
...
@@ -18,7 +18,6 @@ template <index_t BlockSize,
typename
DstElementwiseOperation
,
InMemoryDataOperationEnum_t
DstInMemOp
,
typename
BlockSliceLengths
,
typename
ThreadSliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcData
,
...
...
@@ -39,6 +38,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
static
constexpr
auto
thread_slice_lengths
=
BlockSliceLengths
{}
/
ThreadClusterLengths
{};
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
BlockwiseTensorSliceTransfer_v4r1
(
...
...
@@ -58,14 +59,13 @@ struct BlockwiseTensorSliceTransfer_v4r1
{
static_assert
(
nDim
==
remove_reference_t
<
remove_cv_t
<
SrcDesc
>>::
GetNumOfDimension
()
&&
nDim
==
remove_reference_t
<
remove_cv_t
<
DstDesc
>>::
GetNumOfDimension
()
&&
nDim
==
BlockSliceLengths
::
Size
()
&&
nDim
==
ThreadSliceLengths
::
Size
()
&&
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
SrcDimAccessOrder
::
Size
()
&&
nDim
==
DstDimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
BlockSliceLengths
,
decltype
(
T
hread
S
lice
L
engths
{}
*
ThreadClusterLengths
{})
>
{},
is_same
<
BlockSliceLengths
,
decltype
(
t
hread
_s
lice
_l
engths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
BlockSize
>=
thread_cluster_desc_
.
GetElementSize
(),
...
...
@@ -77,7 +77,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
T
hread
S
lice
L
engths
{}
;
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
t
hread
_s
lice
_l
engths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
...
...
@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceTransfer_v4r1
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v3r1
<
T
hread
S
lice
L
engths
,
ThreadwiseTensorSliceTransfer_v3r1
<
decltype
(
t
hread
_s
lice
_l
engths
)
,
SrcElementwiseOperation
,
DstElementwiseOperation
,
DstInMemOp
,
...
...
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r1.hpp
View file @
925a8d78
...
...
@@ -17,7 +17,6 @@ template <index_t BlockSize,
typename
ElementwiseOperation
,
InMemoryDataOperationEnum_t
DstInMemOp
,
typename
BlockSliceLengths
,
typename
ThreadSliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcData
,
...
...
@@ -33,6 +32,8 @@ struct BlockwiseTensorSliceTransfer_v6r1
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
static
constexpr
auto
thread_slice_lengths
=
BlockSliceLengths
{}
/
ThreadClusterLengths
{};
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
BlockwiseTensorSliceTransfer_v6r1
(
const
SrcDesc
&
src_desc
,
...
...
@@ -49,14 +50,13 @@ struct BlockwiseTensorSliceTransfer_v6r1
{
static_assert
(
nDim
==
remove_reference_t
<
remove_cv_t
<
SrcDesc
>>::
GetNumOfDimension
()
&&
nDim
==
remove_reference_t
<
remove_cv_t
<
DstDesc
>>::
GetNumOfDimension
()
&&
nDim
==
BlockSliceLengths
::
Size
()
&&
nDim
==
ThreadSliceLengths
::
Size
()
&&
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
DimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
BlockSliceLengths
,
decltype
(
T
hread
S
lice
L
engths
{}
*
ThreadClusterLengths
{})
>
{},
is_same
<
BlockSliceLengths
,
decltype
(
t
hread
_s
lice
_l
engths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
BlockSize
>=
thread_cluster_desc_
.
GetElementSize
(),
...
...
@@ -68,7 +68,7 @@ struct BlockwiseTensorSliceTransfer_v6r1
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
T
hread
S
lice
L
engths
{}
;
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
t
hread
_s
lice
_l
engths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
...
...
@@ -118,7 +118,7 @@ struct BlockwiseTensorSliceTransfer_v6r1
SrcDesc
,
DstDesc
,
ElementwiseOperation
,
T
hread
S
lice
L
engths
,
decltype
(
t
hread
_s
lice
_l
engths
)
,
DimAccessOrder
,
VectorDim
,
ScalarPerVector
,
...
...
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v6r3.hpp
View file @
925a8d78
...
...
@@ -17,7 +17,6 @@ template <index_t BlockSize,
typename
ElementwiseOperation
,
InMemoryDataOperationEnum_t
DstInMemOp
,
typename
BlockSliceLengths
,
typename
ThreadSliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
Src0Data
,
...
...
@@ -39,6 +38,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
Src0Desc
>::
GetNumOfDimension
();
static
constexpr
auto
thread_slice_lengths
=
BlockSliceLengths
{}
/
ThreadClusterLengths
{};
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
BlockwiseTensorSliceTransfer_v6r3
(
const
Src0Desc
&
src0_desc
,
...
...
@@ -65,14 +66,13 @@ struct BlockwiseTensorSliceTransfer_v6r3
nDim
==
remove_reference_t
<
remove_cv_t
<
Src1Desc
>>::
GetNumOfDimension
()
&&
nDim
==
remove_reference_t
<
remove_cv_t
<
Src2Desc
>>::
GetNumOfDimension
()
&&
nDim
==
remove_reference_t
<
remove_cv_t
<
DstDesc
>>::
GetNumOfDimension
()
&&
nDim
==
BlockSliceLengths
::
Size
()
&&
nDim
==
ThreadSliceLengths
::
Size
()
&&
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
DimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
BlockSliceLengths
,
decltype
(
T
hread
S
lice
L
engths
{}
*
ThreadClusterLengths
{})
>
{},
is_same
<
BlockSliceLengths
,
decltype
(
t
hread
_s
lice
_l
engths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
BlockSize
>=
thread_cluster_desc_
.
GetElementSize
(),
...
...
@@ -84,7 +84,7 @@ struct BlockwiseTensorSliceTransfer_v6r3
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
T
hread
S
lice
L
engths
{}
;
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
t
hread
_s
lice
_l
engths
;
threadwise_transfer_
.
SetSrc0SliceOrigin
(
src0_desc
,
src0_block_slice_origin
+
thread_data_idx_begin
);
...
...
@@ -165,7 +165,7 @@ struct BlockwiseTensorSliceTransfer_v6r3
Src2Desc
,
DstDesc
,
ElementwiseOperation
,
T
hread
S
lice
L
engths
,
decltype
(
t
hread
_s
lice
_l
engths
)
,
DimAccessOrder
,
VectorDim
,
ScalarPerVector
,
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
View file @
925a8d78
...
...
@@ -56,50 +56,46 @@ __global__ void
block_2_ctile_map
);
}
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
index_t
MRepeatPerShuffle_CCopy
,
index_t
NRepeatPerShuffle_CCopy
,
index_t
MRepeatThread_CCopy
,
index_t
MThread_CCopy
,
index_t
NRepeatThread_CCopy
,
index_t
NThread_CCopy
,
index_t
NScalarPerVector_CCopy
>
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -363,7 +359,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
K0PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -394,7 +389,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -500,54 +494,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
// shuffle and write out
// shuffle
C
and write out
{
#if 0
// TODO: make it tunable
constexpr index_t MRepeatPerShuffle_CCopy = 1;
constexpr index_t NRepeatPerShuffle_CCopy = 1;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr index_t MRepeatThread_CCopy = 1;
constexpr index_t MThread_CCopy = 32;
constexpr index_t NRepeatThread_CCopy = 1;
constexpr index_t NThread_CCopy = 8;
// vector length for blockwise copy from LDS to global
constexpr index_t NScalarPerVector_CCopy = 8;
#elif
0
// TODO: make it tunable
constexpr
index_t
MRepeatPerShuffle_CCopy
=
1
;
constexpr
index_t
NRepeatPerShuffle_CCopy
=
2
;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr
index_t
MRepeatThread_CCopy
=
1
;
constexpr
index_t
MThread_CCopy
=
16
;
constexpr
index_t
NRepeatThread_CCopy
=
2
;
constexpr
index_t
NThread_CCopy
=
8
;
// vector length for blockwise copy from LDS to global
constexpr
index_t
NScalarPerVector_CCopy
=
8
;
#endif
static_assert
(
MRepeat
%
MRepeatPerShuffle_CCopy
==
0
&&
NRepeat
%
NRepeatPerShuffle_CCopy
==
0
,
static_assert
(
MRepeat
%
CShuffleMRepeatPerShuffle
==
0
&&
NRepeat
%
CShuffleNRepeatPerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerXdl
);
constexpr
index_t
MPerBlock_CCopy
=
MWave
*
MPerXdl
;
constexpr
index_t
NPerBlock_CCopy
=
NWave
*
NPerXdl
;
constexpr
index_t
MPerThread_CCopy
=
MPerBlock_CCopy
/
MThread_CCopy
;
constexpr
index_t
NPerThread_CCopy
=
NPerBlock_CCopy
/
NThread_CCopy
;
constexpr
index_t
MRepeatPerThread_CCopy
=
MRepeatPerShuffle_CCopy
/
MRepeatThread_CCopy
;
constexpr
index_t
NRepeatPerThread_CCopy
=
NRepeatPerShuffle_CCopy
/
NRepeatThread_CCopy
;
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
...
...
@@ -568,10 +523,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr
auto
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeatPerShuffle
_CCopy
>
{},
Number
<
CShuffle
MRepeatPerShuffle
>
{},
Number
<
MWave
*
MPerXdl
>
{},
I1
,
Number
<
NRepeatPerShuffle
_CCopy
>
{},
Number
<
CShuffle
NRepeatPerShuffle
>
{},
Number
<
NWave
*
NPerXdl
>
{}));
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
...
...
@@ -583,12 +538,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
make_tuple
(
make_freeze_transform
(
I0
),
// freeze mblock
make_pass_through_transform
(
Number
<
MRepeatPerShuffle
_CCopy
>
{}),
// M0 (MRepeat) per shuffle
Number
<
CShuffle
MRepeatPerShuffle
>
{}),
// M0 (MRepeat) per shuffle
make_unmerge_transform
(
make_tuple
(
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform
(
I0
),
// freeze nblock
make_pass_through_transform
(
Number
<
NRepeatPerShuffle
_CCopy
>
{}),
// N0 (NRepeat) per shuffle
Number
<
CShuffle
NRepeatPerShuffle
>
{}),
// N0 (NRepeat) per shuffle
make_unmerge_transform
(
make_tuple
(
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_tuple
(
Sequence
<
0
>
{},
...
...
@@ -635,61 +590,58 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
make_multi_index
(
n_thread_data_on_block
));
// VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
MRepeatPerShuffle_CCopy
,
NRepeatPerShuffle_CCopy
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum_t
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum_t
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
c_block_copy_lds_to_global
=
BlockwiseTensorSliceTransfer_v6r1
<
BlockSize
,
// index_t BlockSize,
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
MRepeatPerShuffle_CCopy
,
MPerBlock_CCopy
,
1
,
NRepeatPerShuffle_CCopy
,
NPerBlock_CCopy
>
,
// BlockSliceLengths,
Sequence
<
1
,
MRepeatPerShuffle_CCopy
,
MPerThread_CCopy
,
1
,
NRepeatPerShuffle_CCopy
,
NPerThread_CCopy
>
,
// ThreadSliceLengths,
Sequence
<
1
,
MRepeatPerThread_CCopy
,
MThread_CCopy
,
CShuffleMRepeatPerShuffle
,
MWave
*
MPerXdl
,
1
,
NRepeatPerThread_CCopy
,
NThread_CCopy
>
,
// ThreadClusterLengths,
CShuffleNRepeatPerShuffle
,
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename ThreadClusterArrangeOrder,
FloatC
,
// typename SrcData,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
N
ScalarPerVector_
CCopy
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
CBlockTransfer
ScalarPerVector_
NWaveNPerXdl
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
...
...
@@ -697,22 +649,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
c_element_op
};
constexpr
auto
mrepeat_forward_step
=
make_multi_index
(
0
,
MRepeatPerShuffle
_CCopy
,
0
,
0
,
0
,
0
);
make_multi_index
(
0
,
CShuffle
MRepeatPerShuffle
,
0
,
0
,
0
,
0
);
constexpr
auto
nrepeat_forward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
NRepeatPerShuffle
_CCopy
,
0
);
make_multi_index
(
0
,
0
,
0
,
0
,
CShuffle
NRepeatPerShuffle
,
0
);
constexpr
auto
nrepeat_backward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
-
NRepeatPerShuffle
_CCopy
,
0
);
make_multi_index
(
0
,
0
,
0
,
0
,
-
CShuffle
NRepeatPerShuffle
,
0
);
static_for
<
0
,
MRepeat
,
MRepeatPerShuffle
_CCopy
>
{}([
&
](
auto
mrepeat_iter
)
{
static_for
<
0
,
MRepeat
,
CShuffle
MRepeatPerShuffle
>
{}([
&
](
auto
mrepeat_iter
)
{
constexpr
auto
mrepeat
=
mrepeat_iter
;
static_for
<
0
,
NRepeat
,
NRepeatPerShuffle
_CCopy
>
{}([
&
](
auto
nrepeat_iter
)
{
static_for
<
0
,
NRepeat
,
CShuffle
NRepeatPerShuffle
>
{}([
&
](
auto
nrepeat_iter
)
{
constexpr
bool
nrepeat_forward_sweep
=
(
mrepeat
%
(
2
*
MRepeatPerShuffle
_CCopy
)
==
0
);
(
mrepeat
%
(
2
*
CShuffle
MRepeatPerShuffle
)
==
0
);
constexpr
index_t
nrepeat_value
=
nrepeat_forward_sweep
?
nrepeat_iter
:
(
NRepeat
-
nrepeat_iter
-
NRepeatPerShuffle_CCopy
);
nrepeat_forward_sweep
?
nrepeat_iter
:
(
NRepeat
-
nrepeat_iter
-
CShuffleNRepeatPerShuffle
);
constexpr
auto
nrepeat
=
Number
<
nrepeat_value
>
{};
...
...
@@ -739,7 +692,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// move on nrepeat dimension
if
constexpr
(
nrepeat_forward_sweep
&&
(
nrepeat
<
NRepeat
-
NRepeatPerShuffle
_CCopy
))
(
nrepeat
<
NRepeat
-
CShuffle
NRepeatPerShuffle
))
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
...
...
@@ -754,7 +707,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
});
// move on mrepeat dimension
if
constexpr
(
mrepeat
<
MRepeat
-
MRepeatPerShuffle
_CCopy
)
if
constexpr
(
mrepeat
<
MRepeat
-
CShuffle
MRepeatPerShuffle
)
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r3.hpp
View file @
925a8d78
...
...
@@ -68,52 +68,48 @@ __global__ void
block_2_ctile_map
);
}
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
C0GridDesc_M_N
,
typename
C1GridDesc_M_N
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
index_t
MRepeatPerShuffle_CCopy
,
index_t
NRepeatPerShuffle_CCopy
,
index_t
MRepeatThread_CCopy
,
index_t
MThread_CCopy
,
index_t
NRepeatThread_CCopy
,
index_t
NThread_CCopy
,
index_t
NScalarPerVector_CCopy
>
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
C0GridDesc_M_N
,
typename
C1GridDesc_M_N
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
@@ -402,7 +398,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
K0PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -433,7 +428,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum_t
::
Set
,
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -539,54 +533,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
// shuffle and write out
// shuffle
C
and write out
{
#if 0
// TODO: make it tunable
constexpr index_t MRepeatPerShuffle_CCopy = 1;
constexpr index_t NRepeatPerShuffle_CCopy = 1;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr index_t MRepeatThread_CCopy = 1;
constexpr index_t MThread_CCopy = 32;
constexpr index_t NRepeatThread_CCopy = 1;
constexpr index_t NThread_CCopy = 8;
// vector length for blockwise copy from LDS to global
constexpr index_t NScalarPerVector_CCopy = 8;
#elif
0
// TODO: make it tunable
constexpr
index_t
MRepeatPerShuffle_CCopy
=
1
;
constexpr
index_t
NRepeatPerShuffle_CCopy
=
2
;
// TODO: this is hardcoded, only works for BlockSize = 256. fix it!
constexpr
index_t
MRepeatThread_CCopy
=
1
;
constexpr
index_t
MThread_CCopy
=
16
;
constexpr
index_t
NRepeatThread_CCopy
=
2
;
constexpr
index_t
NThread_CCopy
=
8
;
// vector length for blockwise copy from LDS to global
constexpr
index_t
NScalarPerVector_CCopy
=
8
;
#endif
static_assert
(
MRepeat
%
MRepeatPerShuffle_CCopy
==
0
&&
NRepeat
%
NRepeatPerShuffle_CCopy
==
0
,
static_assert
(
MRepeat
%
CShuffleMRepeatPerShuffle
==
0
&&
NRepeat
%
CShuffleNRepeatPerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerXdl
);
constexpr
index_t
MPerBlock_CCopy
=
MWave
*
MPerXdl
;
constexpr
index_t
NPerBlock_CCopy
=
NWave
*
NPerXdl
;
constexpr
index_t
MPerThread_CCopy
=
MPerBlock_CCopy
/
MThread_CCopy
;
constexpr
index_t
NPerThread_CCopy
=
NPerBlock_CCopy
/
NThread_CCopy
;
constexpr
index_t
MRepeatPerThread_CCopy
=
MRepeatPerShuffle_CCopy
/
MRepeatThread_CCopy
;
constexpr
index_t
NRepeatPerThread_CCopy
=
NRepeatPerShuffle_CCopy
/
NRepeatThread_CCopy
;
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
...
...
@@ -607,10 +562,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
constexpr
auto
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeatPerShuffle
_CCopy
>
{},
Number
<
CShuffle
MRepeatPerShuffle
>
{},
Number
<
MWave
*
MPerXdl
>
{},
I1
,
Number
<
NRepeatPerShuffle
_CCopy
>
{},
Number
<
CShuffle
NRepeatPerShuffle
>
{},
Number
<
NWave
*
NPerXdl
>
{}));
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
...
...
@@ -622,12 +577,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
make_tuple
(
make_freeze_transform
(
I0
),
// freeze mblock
make_pass_through_transform
(
Number
<
MRepeatPerShuffle
_CCopy
>
{}),
// M0 (MRepeat) per shuffle
Number
<
CShuffle
MRepeatPerShuffle
>
{}),
// M0 (MRepeat) per shuffle
make_unmerge_transform
(
make_tuple
(
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_freeze_transform
(
I0
),
// freeze nblock
make_pass_through_transform
(
Number
<
NRepeatPerShuffle
_CCopy
>
{}),
// N0 (NRepeat) per shuffle
Number
<
CShuffle
NRepeatPerShuffle
>
{}),
// N0 (NRepeat) per shuffle
make_unmerge_transform
(
make_tuple
(
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXdl
make_tuple
(
Sequence
<
0
>
{},
...
...
@@ -674,51 +629,48 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
make_multi_index
(
n_thread_data_on_block
));
// VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
MRepeatPerShuffle_CCopy
,
NRepeatPerShuffle_CCopy
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum_t
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum_t
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
c_block_copy_lds_to_global
=
BlockwiseTensorSliceTransfer_v6r3
<
BlockSize
,
// index_t BlockSize,
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
MRepeatPerShuffle_CCopy
,
MPerBlock_CCopy
,
1
,
NRepeatPerShuffle_CCopy
,
NPerBlock_CCopy
>
,
// BlockSliceLengths,
Sequence
<
1
,
MRepeatPerShuffle_CCopy
,
MPerThread_CCopy
,
1
,
NRepeatPerShuffle_CCopy
,
NPerThread_CCopy
>
,
// ThreadSliceLengths,
Sequence
<
1
,
MRepeatPerThread_CCopy
,
MThread_CCopy
,
CShuffleMRepeatPerShuffle
,
MWave
*
MPerXdl
,
1
,
NRepeatPerThread_CCopy
,
NThread_CCopy
>
,
// ThreadClusterLengths,
CShuffleNRepeatPerShuffle
,
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename ThreadClusterArrangeOrder,
FloatC
,
// typename Src0Data,
FloatC
,
// typename Src1Data,
...
...
@@ -728,13 +680,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
decltype
(
c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c1_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
decltype
(
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
),
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
N
ScalarPerVector_
CCopy
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrc0ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc1ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc2ResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
// typename DimAccessOrder,
5
,
// index_t VectorDim,
CBlockTransfer
ScalarPerVector_
NWaveNPerXdl
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrc0ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc1ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc2ResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_block_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
,
0
),
c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
...
...
@@ -746,22 +698,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
c_element_op
};
constexpr
auto
mrepeat_forward_step
=
make_multi_index
(
0
,
MRepeatPerShuffle
_CCopy
,
0
,
0
,
0
,
0
);
make_multi_index
(
0
,
CShuffle
MRepeatPerShuffle
,
0
,
0
,
0
,
0
);
constexpr
auto
nrepeat_forward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
NRepeatPerShuffle
_CCopy
,
0
);
make_multi_index
(
0
,
0
,
0
,
0
,
CShuffle
NRepeatPerShuffle
,
0
);
constexpr
auto
nrepeat_backward_step
=
make_multi_index
(
0
,
0
,
0
,
0
,
-
NRepeatPerShuffle
_CCopy
,
0
);
make_multi_index
(
0
,
0
,
0
,
0
,
-
CShuffle
NRepeatPerShuffle
,
0
);
static_for
<
0
,
MRepeat
,
MRepeatPerShuffle
_CCopy
>
{}([
&
](
auto
mrepeat_iter
)
{
static_for
<
0
,
MRepeat
,
CShuffle
MRepeatPerShuffle
>
{}([
&
](
auto
mrepeat_iter
)
{
constexpr
auto
mrepeat
=
mrepeat_iter
;
static_for
<
0
,
NRepeat
,
NRepeatPerShuffle
_CCopy
>
{}([
&
](
auto
nrepeat_iter
)
{
static_for
<
0
,
NRepeat
,
CShuffle
NRepeatPerShuffle
>
{}([
&
](
auto
nrepeat_iter
)
{
constexpr
bool
nrepeat_forward_sweep
=
(
mrepeat
%
(
2
*
MRepeatPerShuffle
_CCopy
)
==
0
);
(
mrepeat
%
(
2
*
CShuffle
MRepeatPerShuffle
)
==
0
);
constexpr
index_t
nrepeat_value
=
nrepeat_forward_sweep
?
nrepeat_iter
:
(
NRepeat
-
nrepeat_iter
-
NRepeatPerShuffle_CCopy
);
nrepeat_forward_sweep
?
nrepeat_iter
:
(
NRepeat
-
nrepeat_iter
-
CShuffleNRepeatPerShuffle
);
constexpr
auto
nrepeat
=
Number
<
nrepeat_value
>
{};
...
...
@@ -792,7 +745,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
// move on nrepeat dimension
if
constexpr
(
nrepeat_forward_sweep
&&
(
nrepeat
<
NRepeat
-
NRepeatPerShuffle
_CCopy
))
(
nrepeat
<
NRepeat
-
CShuffle
NRepeatPerShuffle
))
{
c_block_copy_lds_to_global
.
MoveSrc1SliceWindow
(
c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
...
...
@@ -823,7 +776,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
});
// move on mrepeat dimension
if
constexpr
(
mrepeat
<
MRepeat
-
MRepeatPerShuffle
_CCopy
)
if
constexpr
(
mrepeat
<
MRepeat
-
CShuffle
MRepeatPerShuffle
)
{
c_block_copy_lds_to_global
.
MoveSrc1SliceWindow
(
c0_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl
,
...
...
device_operation/include/device_conv2d_fwd_xdl_output_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
925a8d78
...
...
@@ -18,45 +18,41 @@ namespace device {
// out[N, Ho, Wo, K] =
// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K]
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
MRepeatPerShuffle_CCopy
,
index_t
NRepeatPerShuffle_CCopy
,
index_t
MRepeatThread_CCopy
,
index_t
MThread_CCopy
,
index_t
NRepeatThread_CCopy
,
index_t
NThread_CCopy
,
index_t
NScalarPerVector_CCopy
>
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv2dFwdXdl_Output_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvFwdBiasActivationAdd
<
InElementwiseOperation
,
...
...
@@ -257,7 +253,6 @@ struct
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder,
Sequence
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder,
...
...
@@ -266,7 +261,6 @@ struct
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder,
Sequence
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder,
...
...
@@ -275,13 +269,10 @@ struct
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
MRepeatPerShuffle_CCopy
,
NRepeatPerShuffle_CCopy
,
MRepeatThread_CCopy
,
MThread_CCopy
,
NRepeatThread_CCopy
,
NThread_CCopy
,
NScalarPerVector_CCopy
>
;
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
,
CBlockTransferScalarPerVector_NWaveNPerXdl
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
device_operation/include/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp
View file @
925a8d78
...
...
@@ -17,45 +17,41 @@ namespace tensor_operation {
namespace
device
{
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerXdl
,
ck
::
index_t
NPerXdl
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadSliceLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
MRepeatPerShuffle_CCopy
,
index_t
NRepeatPerShuffle_CCopy
,
index_t
MRepeatThread_CCopy
,
index_t
MThread_CCopy
,
index_t
NRepeatThread_CCopy
,
index_t
NThread_CCopy
,
index_t
NScalarPerVector_CCopy
>
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerXdl
,
ck
::
index_t
NPerXdl
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvFwd
<
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
...
...
@@ -238,7 +234,6 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder,
Sequence
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder,
...
...
@@ -247,7 +242,6 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder,
Sequence
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder,
...
...
@@ -256,13 +250,10 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
MRepeatPerShuffle_CCopy
,
NRepeatPerShuffle_CCopy
,
MRepeatThread_CCopy
,
MThread_CCopy
,
NRepeatThread_CCopy
,
NThread_CCopy
,
NScalarPerVector_CCopy
>
;
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CBlockTransferClusterLengths_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
,
CBlockTransferScalarPerVector_NWaveNPerXdl
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
example/4_conv2d_fwd_xdl_output_shuffle/conv2d_fwd_xdl_output_shuffle.cpp
View file @
925a8d78
...
...
@@ -33,11 +33,11 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough_v2;
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer|
ABlockTransfer|
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer|
BBlockTransfer|
BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer|
B
Block
Lds| MRepeatPer| NRepeatPer| MRepeat| MThread| NRepeat| NThread| NScalarP
er|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per|
ThreadSlice|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
ThreadSlice|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
Shuffle| Shuffle| Thread| _CCopy| Thread| _CCopy|
Vector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave|
Lengths_K0_N_K1|
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1|
Lengths_K0_N_K1|
ArrangeOrder| | | PerVector| PerVector_K1| |
_CCopy| _CCopy| _CCopy| | _CCopy| | _CCopy
|
// | | | | | | | | | | | | | | | | | |
| | | | |
| |
| | | | | | | | | |
|
|
|
|
|
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
1
,
32
,
1
,
8
,
8
>
;
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer
| BBlockLds| CShuffle| CShuffle
|
C
Block
TransferClusterLengths| CBlockTransf
er|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeate| _MBlock_MRepeat_MWaveMPerXdl| ScalarPer
Vector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle| PerShuffle| _NBlock_NRepeat_NWaveNPerXdl| _NWaveNPerXdl
|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
| |
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
;
// clang-format on
template
<
typename
TIn
,
...
...
example/6_conv2d_fwd_xdl_output_shuffle_bias_relu_add/conv2d_fwd_xdl_output_shuffle_bias_relu_add.cpp
View file @
925a8d78
...
...
@@ -33,11 +33,11 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd_v2;
// clang-format off
using
DeviceConvFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceConv2dFwdXdl_Output_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer|
ABlockTransfer|
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer|
BBlockTransfer|
BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer|
B
Block
Lds| MRepeatPer| NRepeatPer| MRepeat| MThread| NRepeat| NThread| NScalarP
er|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per|
ThreadSlice|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
ThreadSlice|
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
Shuffle| Shuffle| Thread| _CCopy| Thread| _CCopy|
Vector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave|
Lengths_K0_N_K1|
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1|
Lengths_K0_N_K1|
ArrangeOrder| | | PerVector| PerVector_K1| |
_CCopy| _CCopy| _CCopy| | _CCopy| | _CCopy
|
// | | | | | | | | | | | | | | | | | |
| | | | |
| |
| | | | | | | | | |
|
|
|
|
|
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
1
,
4
,
8
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
1
,
32
,
1
,
8
,
8
>
;
// | InData| WeiData| OutData| AccData| In| Wei| Out| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer
| BBlockLds| CShuffle| CShuffle
|
C
Block
TransferClusterLengths| CBlockTransf
er|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
MRepeate| NRepeate| _MBlock_MRepeat_MWaveMPerXdl| ScalarPer
Vector|
// | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle| PerShuffle| _NBlock_NRepeat_NWaveNPerXdl| _NWaveNPerXdl
|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
| |
<
InDataType
,
WeiDataType
,
OutDataType
,
AccDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
256
,
128
,
256
,
4
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
8
>
;
// clang-format on
template
<
typename
TIn
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment